From 80dd71c5f519c59e60f13cb75eafc8e0e6a487f7 Mon Sep 17 00:00:00 2001 From: Claude Date: Tue, 26 May 2026 05:07:35 +0000 Subject: [PATCH 01/18] Add param/ret/sig refinement-type annotations via thrust_macros MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Introduce `thrust_macros::param`, `thrust_macros::ret`, and `thrust_macros::sig` attribute macros that lower refinement types (e.g. `{ v: i32 | v > 0 }`) into `#[thrust::formula_fn]`s, giving refinement formulas the same rustc-typechecked treatment as requires/ensures. Each refinement is placed via a new "type position" — a path addressing into the function type (parameter index or return slot, then generic-argument indices for enum args and Box pointees) — emitted as a `#[thrust::refine(..)]` path statement and installed into the parameter or return RefinedType template. Migrate the existing `thrust::sig` tests to `thrust_macros::sig`. --- src/analyze.rs | 79 ++++++ src/analyze/annot.rs | 31 +++ src/analyze/local_def.rs | 7 + src/refine/template.rs | 33 +++ src/rty.rs | 40 +++ tests/ui/fail/annot_box_term.rs | 2 +- tests/ui/fail/refine_param_simple.rs | 9 + tests/ui/fail/refine_sig.rs | 8 + tests/ui/pass/annot_box_term.rs | 2 +- tests/ui/pass/refine_param_simple.rs | 9 + tests/ui/pass/refine_sig.rs | 8 + thrust-macros/src/lib.rs | 397 ++++++++++++++++++++++++++- 12 files changed, 622 insertions(+), 3 deletions(-) create mode 100644 tests/ui/fail/refine_param_simple.rs create mode 100644 tests/ui/fail/refine_sig.rs create mode 100644 tests/ui/pass/refine_param_simple.rs create mode 100644 tests/ui/pass/refine_sig.rs diff --git a/src/analyze.rs b/src/analyze.rs index 45f47558..65df3712 100644 --- a/src/analyze.rs +++ b/src/analyze.rs @@ -790,6 +790,85 @@ impl<'tcx> Analyzer<'tcx> { ensure_annot } + /// Collects every `#[thrust::refine(..)]` path statement in the function + /// body, returning each `(type position, formula_fn DefId)`. + fn extract_refine_paths(&self, local_def_id: LocalDefId) -> Vec<(Vec, DefId)> { + let mut out = Vec::new(); + let Some(body) = self.tcx.hir_maybe_body_owned_by(local_def_id) else { + return out; + }; + let rustc_hir::ExprKind::Block(block, _) = body.value.kind else { + return out; + }; + let attr_path = analyze::annot::refine_path(); + let typeck = self.tcx.typeck(local_def_id); + for stmt in block.stmts { + let Some(attr) = self + .tcx + .hir_attrs(stmt.hir_id) + .iter() + .find(|attr| attr.path_matches(&attr_path)) + else { + continue; + }; + let ts = analyze::annot::extract_annot_tokens(attr.clone()); + let position = analyze::annot::parse_position(&ts); + + let rustc_hir::StmtKind::Semi(expr) = stmt.kind else { + self.tcx.dcx().span_err( + stmt.span, + "annotated path is expected to be a semi statement", + ); + continue; + }; + let rustc_hir::ExprKind::Path(qpath) = expr.kind else { + self.tcx.dcx().span_err( + expr.span, + "annotated path is expected to be a path expression", + ); + continue; + }; + let rustc_hir::def::Res::Def(_, def_id) = typeck.qpath_res(&qpath, expr.hir_id) else { + self.tcx.dcx().span_err( + expr.span, + "annotated path is expected to refer to a definition", + ); + continue; + }; + out.push((position, def_id)); + } + out + } + + /// Resolves every `#[thrust::refine(..)]` annotation into a positioned + /// refinement, by translating the referenced formula function. + pub fn extract_refine_annots( + &self, + local_def_id: LocalDefId, + generic_args: mir_ty::GenericArgsRef<'tcx>, + ) -> Vec<(Vec, rty::Refinement)> { + let mut out = Vec::new(); + for (position, def_id) in self.extract_refine_paths(local_def_id) { + let Some(formula_def_id) = def_id.as_local() else { + panic!( + "refine annotation with path is expected to refer to a local def, but found: {:?}", + def_id + ); + }; + let Some(formula_fn) = self.formula_fn_with_args(formula_def_id, generic_args) else { + panic!( + "refine annotation {:?} is not a formula function", + formula_def_id + ); + }; + let AnnotFormula::Formula(formula) = formula_fn.to_ensure_annot() else { + panic!("refine annotation must lower to a plain formula"); + }; + out.push((position, formula.into())); + } + out + } + /// Whether the given `def_id` corresponds to a method of one of the `Fn` traits. fn is_fn_trait_method(&self, def_id: DefId) -> bool { self.tcx diff --git a/src/analyze/annot.rs b/src/analyze/annot.rs index 6a40943e..b969fb56 100644 --- a/src/analyze/annot.rs +++ b/src/analyze/annot.rs @@ -61,6 +61,10 @@ pub fn ensures_path_path() -> [Symbol; 2] { [Symbol::intern("thrust"), Symbol::intern("ensures_path")] } +pub fn refine_path() -> [Symbol; 2] { + [Symbol::intern("thrust"), Symbol::intern("refine")] +} + pub fn model_ty_path() -> [Symbol; 3] { [ Symbol::intern("thrust"), @@ -215,6 +219,33 @@ pub fn extract_annot_tokens(attr: Attribute) -> TokenStream { d.tokens } +/// Parses a comma-separated list of integer literals (a "type position") from +/// the tokens of a `#[thrust::refine(..)]` attribute. +pub fn parse_position(ts: &TokenStream) -> Vec { + use rustc_ast::token::{LitKind, TokenKind}; + use rustc_ast::tokenstream::TokenTree; + + let mut out = Vec::new(); + for tt in ts.iter() { + match tt { + TokenTree::Token(t, _) => match &t.kind { + TokenKind::Comma => {} + TokenKind::Literal(lit) if lit.kind == LitKind::Integer => { + out.push( + lit.symbol + .as_str() + .parse() + .expect("invalid integer in refine position"), + ); + } + _ => panic!("unexpected token in refine position: {:?}", t), + }, + _ => panic!("unexpected token tree in refine position"), + } + } + out +} + pub fn split_param(ts: &TokenStream) -> (Ident, TokenStream) { use rustc_ast::token::TokenKind; use rustc_ast::tokenstream::TokenTree; diff --git a/src/analyze/local_def.rs b/src/analyze/local_def.rs index 64d4100d..5fd6f6ce 100644 --- a/src/analyze/local_def.rs +++ b/src/analyze/local_def.rs @@ -406,6 +406,10 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { assert!(require_annot.is_none() || param_annots.is_empty()); assert!(ensure_annot.is_none() || ret_annot.is_none()); + let refine_annots = self + .ctx + .extract_refine_annots(self.local_def_id, self.generic_args); + let trait_item_ty = self.trait_item_ty(); let is_fully_annotated = self.is_fully_annotated(); @@ -431,6 +435,9 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { if let Some(ret_rty) = ret_annot { builder.ret_rty(ret_rty); } + for (position, refinement) in refine_annots { + builder.refine(&position, refinement); + } if is_fully_annotated { rty::RefinedType::unrefined(builder.build().into()) diff --git a/src/refine/template.rs b/src/refine/template.rs index ed0762ed..074b8bb7 100644 --- a/src/refine/template.rs +++ b/src/refine/template.rs @@ -565,6 +565,39 @@ impl<'tcx, 'a, R> FunctionTemplateTypeBuilder<'tcx, 'a, R> { self.ret_rty = Some(rty); self } + + /// Installs a refinement at a function-type position. The first index of + /// `path` selects a parameter (by index) or the return slot (when it equals + /// the parameter count); the remaining indices descend into the slot's type. + pub fn refine( + &mut self, + path: &[usize], + refinement: rty::Refinement, + ) -> &mut Self { + let (&slot, sub) = path.split_first().expect("refine path must be non-empty"); + let n = self.param_tys.len(); + if slot < n { + let idx = rty::FunctionParamIdx::from(slot); + if !self.param_rtys.contains_key(&idx) { + let ty = self.inner.build(self.param_tys[slot].ty).vacuous(); + self.param_rtys.insert(idx, rty::RefinedType::unrefined(ty)); + } + self.param_rtys + .get_mut(&idx) + .unwrap() + .install_refinement_at(sub, refinement); + } else { + if self.ret_rty.is_none() { + let ty = self.inner.build(self.ret_ty).vacuous(); + self.ret_rty = Some(rty::RefinedType::unrefined(ty)); + } + self.ret_rty + .as_mut() + .unwrap() + .install_refinement_at(sub, refinement); + } + self + } } impl<'tcx, 'a, R> FunctionTemplateTypeBuilder<'tcx, 'a, R> diff --git a/src/rty.rs b/src/rty.rs index c9a3249a..a7c543c0 100644 --- a/src/rty.rs +++ b/src/rty.rs @@ -1475,6 +1475,46 @@ where } } +impl RefinedType { + /// Installs `refinement` at the given type position, descending through the + /// function type (parameters then return), enum type arguments, and `Box` + /// pointees. An empty path replaces the refinement at the current node. + pub fn install_refinement_at( + &mut self, + path: &[usize], + refinement: Refinement, + ) { + let Some((&step, rest)) = path.split_first() else { + self.refinement = refinement; + return; + }; + match &mut self.ty { + Type::Enum(e) => { + let arg = e.args.get_mut(TypeParamIdx::from(step)).unwrap_or_else(|| { + panic!("refine position {} out of range for enum type", step) + }); + arg.install_refinement_at(rest, refinement); + } + Type::Pointer(p) => { + assert_eq!(step, 0, "Box type position must be 0"); + p.elem.install_refinement_at(rest, refinement); + } + Type::Function(f) => { + let n = f.params.len(); + if step < n { + f.params[FunctionParamIdx::from(step)].install_refinement_at(rest, refinement); + } else { + f.ret.install_refinement_at(rest, refinement); + } + } + ty => panic!( + "unsupported type at refine position step {}: {:?}", + step, ty + ), + } + } +} + impl RefinedType { fn pretty_atom<'a, 'b, D>( &'b self, diff --git a/tests/ui/fail/annot_box_term.rs b/tests/ui/fail/annot_box_term.rs index 27184bc6..4391c35e 100644 --- a/tests/ui/fail/annot_box_term.rs +++ b/tests/ui/fail/annot_box_term.rs @@ -1,7 +1,7 @@ //@error-in-other-file: Unsat //@compile-flags: -C debug-assertions=off -#[thrust::sig(fn(x: int) -> {r: Box | r == })] +#[thrust_macros::sig(fn(x: i64) -> { r: Box | r == thrust_models::model::Box::new(x) })] fn box_create(x: i64) -> Box { Box::new(x) } diff --git a/tests/ui/fail/refine_param_simple.rs b/tests/ui/fail/refine_param_simple.rs new file mode 100644 index 00000000..11697be7 --- /dev/null +++ b/tests/ui/fail/refine_param_simple.rs @@ -0,0 +1,9 @@ +//@error-in-other-file: Unsat + +#[thrust_macros::param(x: { v: i32 | v > 0 })] +#[thrust_macros::ret({ r: i32 | r > x })] +fn f(x: i32) -> i32 { + x +} + +fn main() {} diff --git a/tests/ui/fail/refine_sig.rs b/tests/ui/fail/refine_sig.rs new file mode 100644 index 00000000..17b36888 --- /dev/null +++ b/tests/ui/fail/refine_sig.rs @@ -0,0 +1,8 @@ +//@error-in-other-file: Unsat + +#[thrust_macros::sig(fn(x: { v: i32 | v > 0 }) -> { r: i32 | r > x })] +fn g(x: i32) -> i32 { + x +} + +fn main() {} diff --git a/tests/ui/pass/annot_box_term.rs b/tests/ui/pass/annot_box_term.rs index 67c71e16..b96d49e8 100644 --- a/tests/ui/pass/annot_box_term.rs +++ b/tests/ui/pass/annot_box_term.rs @@ -1,7 +1,7 @@ //@check-pass //@compile-flags: -C debug-assertions=off -#[thrust::sig(fn(x: int) -> {r: Box | r == })] +#[thrust_macros::sig(fn(x: i64) -> { r: Box | r == thrust_models::model::Box::new(x) })] fn box_create(x: i64) -> Box { Box::new(x) } diff --git a/tests/ui/pass/refine_param_simple.rs b/tests/ui/pass/refine_param_simple.rs new file mode 100644 index 00000000..bd732583 --- /dev/null +++ b/tests/ui/pass/refine_param_simple.rs @@ -0,0 +1,9 @@ +//@check-pass + +#[thrust_macros::param(x: { v: i32 | v > 0 })] +#[thrust_macros::ret({ r: i32 | r >= x })] +fn f(x: i32) -> i32 { + x +} + +fn main() {} diff --git a/tests/ui/pass/refine_sig.rs b/tests/ui/pass/refine_sig.rs new file mode 100644 index 00000000..db233970 --- /dev/null +++ b/tests/ui/pass/refine_sig.rs @@ -0,0 +1,8 @@ +//@check-pass + +#[thrust_macros::sig(fn(x: { v: i32 | v > 0 }) -> { r: i32 | r >= x })] +fn g(x: i32) -> i32 { + x +} + +fn main() {} diff --git a/thrust-macros/src/lib.rs b/thrust-macros/src/lib.rs index 99df73dc..3afc8409 100644 --- a/thrust-macros/src/lib.rs +++ b/thrust-macros/src/lib.rs @@ -1,5 +1,5 @@ use proc_macro::TokenStream; -use proc_macro2::TokenStream as TokenStream2; +use proc_macro2::{TokenStream as TokenStream2, TokenTree as TokenTree2}; use quote::{format_ident, quote, ToTokens}; use syn::{ parse_macro_input, punctuated::Punctuated, FnArg, GenericParam, Generics, TypeParamBound, @@ -571,6 +571,401 @@ impl ExpandedTokens { } } +// --------------------------------------------------------------------------- +// Refinement-type annotations: `param`, `ret`, `sig`. +// +// These lower refinement types (e.g. `List<{ v: i32 | v > 0 }>`) into +// `#[thrust::formula_fn]`s plus positioned `#[thrust::refine(..)]` path +// statements injected into the function body. The "type position" addresses +// into the function type: the first index selects a parameter (by index) or +// the return slot (== parameter count), and subsequent indices descend into +// generic arguments (enum args / `Box` pointee). +// --------------------------------------------------------------------------- + +#[derive(Clone)] +struct Refinement { + path: Vec, + binder: syn::Ident, + binder_ty: TokenStream2, + formula: TokenStream2, +} + +#[proc_macro_attribute] +pub fn param(attr: TokenStream, item: TokenStream) -> TokenStream { + expand_refine(RefineKind::Param, attr, item) +} + +#[proc_macro_attribute] +pub fn ret(attr: TokenStream, item: TokenStream) -> TokenStream { + expand_refine(RefineKind::Ret, attr, item) +} + +#[proc_macro_attribute] +pub fn sig(attr: TokenStream, item: TokenStream) -> TokenStream { + expand_refine(RefineKind::Sig, attr, item) +} + +enum RefineKind { + Param, + Ret, + Sig, +} + +fn expand_refine(kind: RefineKind, attr: TokenStream, item: TokenStream) -> TokenStream { + let mut func = parse_macro_input!(item as FnItemWithSignature); + + let outer_context = match extract_outer_context(&func) { + Ok(ctx) => ctx, + Err(e) => { + let err = e.to_compile_error(); + return quote! { #err #func }.into(); + } + }; + + let attr_tokens: Vec = TokenStream2::from(attr).into_iter().collect(); + let jobs = match build_refine_jobs(kind, &func, &attr_tokens) { + Ok(jobs) => jobs, + Err(e) => { + let err = e.to_compile_error(); + return quote! { #err #func }.into(); + } + }; + + let mut refinements = Vec::new(); + for (root, ty_tokens) in jobs { + if let Err(e) = scan_type(&ty_tokens, &root, &mut refinements) { + let err = e.to_compile_error(); + return quote! { #err #func }.into(); + } + } + + if refinements.is_empty() { + return func.into_token_stream().into(); + } + + let has_receiver = func.sig().receiver().is_some(); + let mut formula_fns = Vec::new(); + let mut path_stmts = Vec::new(); + for mut r in refinements { + if has_receiver { + r.formula = rewrite_self_in_tokens(r.formula); + } + formula_fns.push(refine_formula_fn(&func, outer_context.as_ref(), &r)); + path_stmts.push(refine_path_stmt(&func, &r)); + } + + let Some(block) = func.block_mut() else { + let err = syn::Error::new_spanned( + func.sig().ident.clone(), + "refinement-type annotations require a function body", + ) + .into_compile_error(); + return quote! { #err #func }.into(); + }; + let orig_stmts = block.stmts.drain(..).collect::>(); + *block = syn::parse_quote!({ + #(#path_stmts)* + #(#orig_stmts)* + }); + func.attrs_mut() + .push(syn::parse_quote!(#[allow(path_statements)])); + + quote! { + #(#formula_fns)* + #func + } + .into() +} + +/// Builds `(root_path, type_tokens)` jobs to scan from the attribute tokens. +fn build_refine_jobs( + kind: RefineKind, + func: &FnItemWithSignature, + attr_tokens: &[TokenTree2], +) -> syn::Result, Vec)>> { + let param_count = func.sig().inputs.len(); + match kind { + RefineKind::Param => { + let (name, ty_tokens) = split_name_type(attr_tokens)?; + let idx = param_index(func, &name)?; + Ok(vec![(vec![idx], ty_tokens)]) + } + RefineKind::Ret => Ok(vec![(vec![param_count], attr_tokens.to_vec())]), + RefineKind::Sig => { + let (args, ret_tokens) = parse_sig_attr(attr_tokens)?; + let mut jobs = Vec::new(); + for (name, ty_tokens) in args { + let idx = param_index(func, &name)?; + jobs.push((vec![idx], ty_tokens)); + } + jobs.push((vec![param_count], ret_tokens)); + Ok(jobs) + } + } +} + +fn param_index(func: &FnItemWithSignature, name: &syn::Ident) -> syn::Result { + let pos = func.sig().inputs.iter().position(|arg| match arg { + FnArg::Receiver(_) => name == "self", + FnArg::Typed(pt) => matches!(&*pt.pat, syn::Pat::Ident(pi) if &pi.ident == name), + }); + pos.ok_or_else(|| { + syn::Error::new_spanned(name, format!("no parameter named `{}` in signature", name)) + }) +} + +/// Parses `name : ` from a flat token slice. +fn split_name_type(tokens: &[TokenTree2]) -> syn::Result<(syn::Ident, Vec)> { + let name = match tokens.first() { + Some(TokenTree2::Ident(id)) => id.clone(), + _ => return Err(err_tokens(tokens, "expected a parameter name")), + }; + match tokens.get(1) { + Some(TokenTree2::Punct(p)) if p.as_char() == ':' => {} + _ => return Err(err_tokens(tokens, "expected `:` after parameter name")), + } + Ok((name, tokens[2..].to_vec())) +} + +/// Parses `fn ( n0: t0 , ... ) -> ret` into `((name, ty_tokens)*, ret_tokens)`. +#[allow(clippy::type_complexity)] +fn parse_sig_attr( + tokens: &[TokenTree2], +) -> syn::Result<(Vec<(syn::Ident, Vec)>, Vec)> { + match tokens.first() { + Some(TokenTree2::Ident(id)) if id == "fn" => {} + _ => return Err(err_tokens(tokens, "expected `fn` in sig annotation")), + } + let arg_group = match tokens.get(1) { + Some(TokenTree2::Group(g)) if g.delimiter() == proc_macro2::Delimiter::Parenthesis => g, + _ => return Err(err_tokens(tokens, "expected `(..)` after `fn`")), + }; + + let mut args = Vec::new(); + let arg_tokens: Vec = arg_group.stream().into_iter().collect(); + for arg in split_top_level_commas(&arg_tokens) { + if arg.is_empty() { + continue; + } + args.push(split_name_type(&arg)?); + } + + // expect `->` then the return type + let mut rest = &tokens[2..]; + match (rest.first(), rest.get(1)) { + (Some(TokenTree2::Punct(a)), Some(TokenTree2::Punct(b))) + if a.as_char() == '-' && b.as_char() == '>' => + { + rest = &rest[2..]; + } + _ => { + return Err(err_tokens( + tokens, + "expected `->` and a return type in sig annotation", + )) + } + } + Ok((args, rest.to_vec())) +} + +/// Scans a single type expression, recording every refinement node together +/// with its type position. +fn scan_type(tokens: &[TokenTree2], path: &[usize], out: &mut Vec) -> syn::Result<()> { + if tokens.is_empty() { + return Ok(()); + } + + // A refinement node is exactly a brace-delimited group. + if tokens.len() == 1 { + if let TokenTree2::Group(g) = &tokens[0] { + if g.delimiter() == proc_macro2::Delimiter::Brace { + let (binder, binder_ty, formula) = split_refinement(g.stream())?; + out.push(Refinement { + path: path.to_vec(), + binder, + binder_ty: binder_ty.iter().cloned().collect(), + formula, + }); + // The refinement's own type sits at the same position; descend + // into it to find further nested refinements. + scan_type(&binder_ty, path, out)?; + return Ok(()); + } + } + } + + // A nominal type `Name` (`Box` included). + if let TokenTree2::Ident(_) = &tokens[0] { + if let Some(TokenTree2::Punct(p)) = tokens.get(1) { + if p.as_char() == '<' { + let mut type_idx = 0; + for arg in split_angle_args(&tokens[2..]) { + if is_lifetime(&arg) { + continue; + } + let mut child = path.to_vec(); + child.push(type_idx); + scan_type(&arg, &child, out)?; + type_idx += 1; + } + } + } + } + + Ok(()) +} + +/// Splits `{ binder : ty | formula }` contents into its parts. +fn split_refinement( + stream: TokenStream2, +) -> syn::Result<(syn::Ident, Vec, TokenStream2)> { + let toks: Vec = stream.into_iter().collect(); + let bar = toks + .iter() + .position(|tt| matches!(tt, TokenTree2::Punct(p) if p.as_char() == '|')) + .ok_or_else(|| err_tokens(&toks, "refinement type must contain `|`"))?; + let (binder, binder_ty) = split_name_type(&toks[..bar])?; + let formula: TokenStream2 = toks[bar + 1..].iter().cloned().collect(); + Ok((binder, binder_ty, formula)) +} + +/// Splits the tokens following an opening `<` at top level by commas, stopping +/// at the matching `>`. +fn split_angle_args(tokens: &[TokenTree2]) -> Vec> { + let mut args = Vec::new(); + let mut cur = Vec::new(); + let mut depth = 1usize; + for tt in tokens { + if let TokenTree2::Punct(p) = tt { + match p.as_char() { + '<' => { + depth += 1; + cur.push(tt.clone()); + continue; + } + '>' => { + depth -= 1; + if depth == 0 { + break; + } + cur.push(tt.clone()); + continue; + } + ',' if depth == 1 => { + args.push(std::mem::take(&mut cur)); + continue; + } + _ => {} + } + } + cur.push(tt.clone()); + } + if !cur.is_empty() { + args.push(cur); + } + args +} + +fn split_top_level_commas(tokens: &[TokenTree2]) -> Vec> { + let mut out = Vec::new(); + let mut cur = Vec::new(); + let mut depth = 0i32; + for tt in tokens { + if let TokenTree2::Punct(p) = tt { + match p.as_char() { + '<' => depth += 1, + '>' => depth -= 1, + ',' if depth == 0 => { + out.push(std::mem::take(&mut cur)); + continue; + } + _ => {} + } + } + cur.push(tt.clone()); + } + out.push(cur); + out +} + +fn is_lifetime(tokens: &[TokenTree2]) -> bool { + matches!(tokens.first(), Some(TokenTree2::Punct(p)) if p.as_char() == '\'') +} + +fn err_tokens(tokens: &[TokenTree2], msg: &str) -> syn::Error { + let stream: TokenStream2 = tokens.iter().cloned().collect(); + syn::Error::new_spanned(stream, msg) +} + +fn rewrite_self_in_tokens(tokens: TokenStream2) -> TokenStream2 { + tokens + .into_iter() + .map(|tt| match tt { + TokenTree2::Ident(id) if id == "self" => TokenTree2::Ident(format_ident!("self_")), + TokenTree2::Group(g) => { + let inner = rewrite_self_in_tokens(g.stream()); + TokenTree2::Group(proc_macro2::Group::new(g.delimiter(), inner)) + } + other => other, + }) + .collect() +} + +fn refine_fn_name(func: &FnItemWithSignature, path: &[usize]) -> syn::Ident { + let pos = path + .iter() + .map(|i| i.to_string()) + .collect::>() + .join("_"); + format_ident!("_thrust_refine_{}_{}", func.sig().ident, pos) +} + +fn refine_formula_fn( + func: &FnItemWithSignature, + outer_context: Option<&FnOuterItem>, + r: &Refinement, +) -> TokenStream2 { + let name = refine_fn_name(func, &r.path); + let def_generics = generic_params_tokens(&func.sig().generics); + let model_params = fn_params_with_model_ty(&func.sig().inputs); + let model_preds = model_where_predicates(func, outer_context); + let extended_where = extended_where_clause(func, &model_preds); + let binder = &r.binder; + let binder_ty = &r.binder_ty; + let formula = &r.formula; + + quote! { + #[allow(unused_variables)] + #[allow(non_snake_case)] + #[thrust::formula_fn] + fn #name #def_generics( + #binder: <#binder_ty as thrust_models::Model>::Ty, + #model_params + ) -> bool #extended_where { + #formula + } + } +} + +fn refine_path_stmt(func: &FnItemWithSignature, r: &Refinement) -> TokenStream2 { + let name = refine_fn_name(func, &r.path); + let turbofish = generic_turbofish(&func.sig().generics); + let path_prefix = if func.sig().receiver().is_some() { + quote!(Self::) + } else { + quote!() + }; + let pos = r + .path + .iter() + .map(|i| proc_macro2::Literal::usize_unsuffixed(*i)) + .collect::>(); + quote! { + #[thrust::refine(#(#pos),*)] + #path_prefix #name #turbofish; + } +} + fn mentions_self(sig: &syn::Signature) -> bool { struct Visitor { mentions_self: bool, From 8c3343dc1783e8e4b2e912805a27b614e4ca2f89 Mon Sep 17 00:00:00 2001 From: Claude Date: Tue, 26 May 2026 16:09:32 +0000 Subject: [PATCH 02/18] Model refinement type positions with a structured TypePosition Replace the bare `Vec` type position with `rty::TypePosition` (a `TypePositionRoot` of `Param(idx)` / `Return`, plus a projection of nested type-argument indices). The `#[thrust::refine(..)]` attribute now uses the `result` keyword to select the return instead of the parameter count, which was unintuitive. Adds `Display` (`$1`, `result.0`). --- src/analyze.rs | 6 +-- src/analyze/annot.rs | 50 ++++++++++++++++------- src/refine/template.rs | 46 +++++++++++----------- src/rty.rs | 69 +++++++++++++++++++++++++------- thrust-macros/src/lib.rs | 85 +++++++++++++++++++++++++--------------- 5 files changed, 170 insertions(+), 86 deletions(-) diff --git a/src/analyze.rs b/src/analyze.rs index 65df3712..2111c39f 100644 --- a/src/analyze.rs +++ b/src/analyze.rs @@ -792,7 +792,7 @@ impl<'tcx> Analyzer<'tcx> { /// Collects every `#[thrust::refine(..)]` path statement in the function /// body, returning each `(type position, formula_fn DefId)`. - fn extract_refine_paths(&self, local_def_id: LocalDefId) -> Vec<(Vec, DefId)> { + fn extract_refine_paths(&self, local_def_id: LocalDefId) -> Vec<(rty::TypePosition, DefId)> { let mut out = Vec::new(); let Some(body) = self.tcx.hir_maybe_body_owned_by(local_def_id) else { return out; @@ -812,7 +812,7 @@ impl<'tcx> Analyzer<'tcx> { continue; }; let ts = analyze::annot::extract_annot_tokens(attr.clone()); - let position = analyze::annot::parse_position(&ts); + let position = analyze::annot::parse_type_position(&ts); let rustc_hir::StmtKind::Semi(expr) = stmt.kind else { self.tcx.dcx().span_err( @@ -846,7 +846,7 @@ impl<'tcx> Analyzer<'tcx> { &self, local_def_id: LocalDefId, generic_args: mir_ty::GenericArgsRef<'tcx>, - ) -> Vec<(Vec, rty::Refinement)> { + ) -> Vec<(rty::TypePosition, rty::Refinement)> { let mut out = Vec::new(); for (position, def_id) in self.extract_refine_paths(local_def_id) { let Some(formula_def_id) = def_id.as_local() else { diff --git a/src/analyze/annot.rs b/src/analyze/annot.rs index b969fb56..96854d32 100644 --- a/src/analyze/annot.rs +++ b/src/analyze/annot.rs @@ -219,31 +219,53 @@ pub fn extract_annot_tokens(attr: Attribute) -> TokenStream { d.tokens } -/// Parses a comma-separated list of integer literals (a "type position") from -/// the tokens of a `#[thrust::refine(..)]` attribute. -pub fn parse_position(ts: &TokenStream) -> Vec { +/// Parses a [`rty::TypePosition`] from the tokens of a `#[thrust::refine(..)]` +/// attribute. +/// +/// The first token is the root: the keyword `result` for the return, or an +/// integer for a parameter index. The remaining comma-separated integers form +/// the projection into nested type arguments. For example `result, 0` is the +/// first type-argument of the return, and `1` is the second parameter. +pub fn parse_type_position(ts: &TokenStream) -> rty::TypePosition { use rustc_ast::token::{LitKind, TokenKind}; use rustc_ast::tokenstream::TokenTree; - let mut out = Vec::new(); - for tt in ts.iter() { + let parse_int = |lit: &rustc_ast::token::Lit| -> usize { + assert_eq!( + lit.kind, + LitKind::Integer, + "expected an integer in refine position" + ); + lit.symbol + .as_str() + .parse() + .expect("invalid integer in refine position") + }; + + let mut iter = ts.iter(); + let root = match iter.next() { + Some(TokenTree::Token(t, _)) => match &t.kind { + TokenKind::Ident(sym, _) if sym.as_str() == "result" => rty::TypePositionRoot::Return, + TokenKind::Literal(lit) => { + rty::TypePositionRoot::Param(rty::FunctionParamIdx::from(parse_int(lit))) + } + _ => panic!("unexpected refine position root: {:?}", t), + }, + _ => panic!("empty refine position"), + }; + + let mut projection = Vec::new(); + for tt in iter { match tt { TokenTree::Token(t, _) => match &t.kind { TokenKind::Comma => {} - TokenKind::Literal(lit) if lit.kind == LitKind::Integer => { - out.push( - lit.symbol - .as_str() - .parse() - .expect("invalid integer in refine position"), - ); - } + TokenKind::Literal(lit) => projection.push(parse_int(lit)), _ => panic!("unexpected token in refine position: {:?}", t), }, _ => panic!("unexpected token tree in refine position"), } } - out + rty::TypePosition::new(root, projection) } pub fn split_param(ts: &TokenStream) -> (Ident, TokenStream) { diff --git a/src/refine/template.rs b/src/refine/template.rs index 074b8bb7..04fb8eae 100644 --- a/src/refine/template.rs +++ b/src/refine/template.rs @@ -566,35 +566,35 @@ impl<'tcx, 'a, R> FunctionTemplateTypeBuilder<'tcx, 'a, R> { self } - /// Installs a refinement at a function-type position. The first index of - /// `path` selects a parameter (by index) or the return slot (when it equals - /// the parameter count); the remaining indices descend into the slot's type. + /// Installs a refinement at a [`rty::TypePosition`]. The root selects a + /// parameter or the return slot; the projection then descends into the + /// slot's nested type arguments. pub fn refine( &mut self, - path: &[usize], + position: &rty::TypePosition, refinement: rty::Refinement, ) -> &mut Self { - let (&slot, sub) = path.split_first().expect("refine path must be non-empty"); - let n = self.param_tys.len(); - if slot < n { - let idx = rty::FunctionParamIdx::from(slot); - if !self.param_rtys.contains_key(&idx) { - let ty = self.inner.build(self.param_tys[slot].ty).vacuous(); - self.param_rtys.insert(idx, rty::RefinedType::unrefined(ty)); + match position.root { + rty::TypePositionRoot::Param(idx) => { + if !self.param_rtys.contains_key(&idx) { + let ty = self.inner.build(self.param_tys[idx.index()].ty).vacuous(); + self.param_rtys.insert(idx, rty::RefinedType::unrefined(ty)); + } + self.param_rtys + .get_mut(&idx) + .unwrap() + .install_refinement_at(&position.projection, refinement); } - self.param_rtys - .get_mut(&idx) - .unwrap() - .install_refinement_at(sub, refinement); - } else { - if self.ret_rty.is_none() { - let ty = self.inner.build(self.ret_ty).vacuous(); - self.ret_rty = Some(rty::RefinedType::unrefined(ty)); + rty::TypePositionRoot::Return => { + if self.ret_rty.is_none() { + let ty = self.inner.build(self.ret_ty).vacuous(); + self.ret_rty = Some(rty::RefinedType::unrefined(ty)); + } + self.ret_rty + .as_mut() + .unwrap() + .install_refinement_at(&position.projection, refinement); } - self.ret_rty - .as_mut() - .unwrap() - .install_refinement_at(sub, refinement); } self } diff --git a/src/rty.rs b/src/rty.rs index a7c543c0..e80ef866 100644 --- a/src/rty.rs +++ b/src/rty.rs @@ -83,6 +83,53 @@ where } } +/// Selects a parameter or the return of a function type — the root of a +/// [`TypePosition`]. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum TypePositionRoot { + Param(FunctionParamIdx), + Return, +} + +impl std::fmt::Display for TypePositionRoot { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + TypePositionRoot::Param(idx) => write!(f, "{}", idx), + TypePositionRoot::Return => f.write_str("result"), + } + } +} + +/// A position addressing a sub-type within a function type, used to attach a +/// refinement. +/// +/// The [`root`](Self::root) selects a parameter or the return; the +/// [`projection`](Self::projection) then descends into nested type arguments +/// (enum type-arguments, `Box` pointee). For example, `result.0` addresses the +/// first type-argument of the return type, and `$1` addresses the second +/// parameter. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct TypePosition { + pub root: TypePositionRoot, + pub projection: Vec, +} + +impl TypePosition { + pub fn new(root: TypePositionRoot, projection: Vec) -> Self { + TypePosition { root, projection } + } +} + +impl std::fmt::Display for TypePosition { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "{}", self.root)?; + for p in &self.projection { + write!(f, ".{}", p)?; + } + Ok(()) + } +} + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)] pub enum FunctionAbi { #[default] @@ -1476,22 +1523,22 @@ where } impl RefinedType { - /// Installs `refinement` at the given type position, descending through the - /// function type (parameters then return), enum type arguments, and `Box` - /// pointees. An empty path replaces the refinement at the current node. + /// Installs `refinement` at the given projection — a path of nested + /// type-argument indices descending through enum type arguments and `Box` + /// pointees. An empty projection replaces the refinement at this node. pub fn install_refinement_at( &mut self, - path: &[usize], + projection: &[usize], refinement: Refinement, ) { - let Some((&step, rest)) = path.split_first() else { + let Some((&step, rest)) = projection.split_first() else { self.refinement = refinement; return; }; match &mut self.ty { Type::Enum(e) => { let arg = e.args.get_mut(TypeParamIdx::from(step)).unwrap_or_else(|| { - panic!("refine position {} out of range for enum type", step) + panic!("refine projection {} out of range for enum type", step) }); arg.install_refinement_at(rest, refinement); } @@ -1499,16 +1546,8 @@ impl RefinedType { assert_eq!(step, 0, "Box type position must be 0"); p.elem.install_refinement_at(rest, refinement); } - Type::Function(f) => { - let n = f.params.len(); - if step < n { - f.params[FunctionParamIdx::from(step)].install_refinement_at(rest, refinement); - } else { - f.ret.install_refinement_at(rest, refinement); - } - } ty => panic!( - "unsupported type at refine position step {}: {:?}", + "unsupported type at refine projection step {}: {:?}", step, ty ), } diff --git a/thrust-macros/src/lib.rs b/thrust-macros/src/lib.rs index 3afc8409..c9ddb2d1 100644 --- a/thrust-macros/src/lib.rs +++ b/thrust-macros/src/lib.rs @@ -577,14 +577,23 @@ impl ExpandedTokens { // These lower refinement types (e.g. `List<{ v: i32 | v > 0 }>`) into // `#[thrust::formula_fn]`s plus positioned `#[thrust::refine(..)]` path // statements injected into the function body. The "type position" addresses -// into the function type: the first index selects a parameter (by index) or -// the return slot (== parameter count), and subsequent indices descend into -// generic arguments (enum args / `Box` pointee). +// into the function type: its root selects a parameter (by index) or the +// return (the `result` keyword), and the projection (the remaining indices) +// descends into generic arguments (enum args / `Box` pointee). For example, +// `#[thrust::refine(result, 0)]` is the first type-argument of the return. // --------------------------------------------------------------------------- +/// Root of a refinement's type position: a parameter (by index) or the return. +#[derive(Clone, Copy)] +enum RefineRoot { + Param(usize), + Return, +} + #[derive(Clone)] struct Refinement { - path: Vec, + root: RefineRoot, + projection: Vec, binder: syn::Ident, binder_ty: TokenStream2, formula: TokenStream2, @@ -633,7 +642,7 @@ fn expand_refine(kind: RefineKind, attr: TokenStream, item: TokenStream) -> Toke let mut refinements = Vec::new(); for (root, ty_tokens) in jobs { - if let Err(e) = scan_type(&ty_tokens, &root, &mut refinements) { + if let Err(e) = scan_type(&ty_tokens, root, &[], &mut refinements) { let err = e.to_compile_error(); return quote! { #err #func }.into(); } @@ -677,28 +686,27 @@ fn expand_refine(kind: RefineKind, attr: TokenStream, item: TokenStream) -> Toke .into() } -/// Builds `(root_path, type_tokens)` jobs to scan from the attribute tokens. +/// Builds `(root, type_tokens)` jobs to scan from the attribute tokens. fn build_refine_jobs( kind: RefineKind, func: &FnItemWithSignature, attr_tokens: &[TokenTree2], -) -> syn::Result, Vec)>> { - let param_count = func.sig().inputs.len(); +) -> syn::Result)>> { match kind { RefineKind::Param => { let (name, ty_tokens) = split_name_type(attr_tokens)?; let idx = param_index(func, &name)?; - Ok(vec![(vec![idx], ty_tokens)]) + Ok(vec![(RefineRoot::Param(idx), ty_tokens)]) } - RefineKind::Ret => Ok(vec![(vec![param_count], attr_tokens.to_vec())]), + RefineKind::Ret => Ok(vec![(RefineRoot::Return, attr_tokens.to_vec())]), RefineKind::Sig => { let (args, ret_tokens) = parse_sig_attr(attr_tokens)?; let mut jobs = Vec::new(); for (name, ty_tokens) in args { let idx = param_index(func, &name)?; - jobs.push((vec![idx], ty_tokens)); + jobs.push((RefineRoot::Param(idx), ty_tokens)); } - jobs.push((vec![param_count], ret_tokens)); + jobs.push((RefineRoot::Return, ret_tokens)); Ok(jobs) } } @@ -769,8 +777,14 @@ fn parse_sig_attr( } /// Scans a single type expression, recording every refinement node together -/// with its type position. -fn scan_type(tokens: &[TokenTree2], path: &[usize], out: &mut Vec) -> syn::Result<()> { +/// with its type position (a fixed `root` plus the `projection` accumulated +/// while descending into nested type arguments). +fn scan_type( + tokens: &[TokenTree2], + root: RefineRoot, + projection: &[usize], + out: &mut Vec, +) -> syn::Result<()> { if tokens.is_empty() { return Ok(()); } @@ -781,14 +795,15 @@ fn scan_type(tokens: &[TokenTree2], path: &[usize], out: &mut Vec) - if g.delimiter() == proc_macro2::Delimiter::Brace { let (binder, binder_ty, formula) = split_refinement(g.stream())?; out.push(Refinement { - path: path.to_vec(), + root, + projection: projection.to_vec(), binder, binder_ty: binder_ty.iter().cloned().collect(), formula, }); // The refinement's own type sits at the same position; descend // into it to find further nested refinements. - scan_type(&binder_ty, path, out)?; + scan_type(&binder_ty, root, projection, out)?; return Ok(()); } } @@ -803,9 +818,9 @@ fn scan_type(tokens: &[TokenTree2], path: &[usize], out: &mut Vec) - if is_lifetime(&arg) { continue; } - let mut child = path.to_vec(); + let mut child = projection.to_vec(); child.push(type_idx); - scan_type(&arg, &child, out)?; + scan_type(&arg, root, &child, out)?; type_idx += 1; } } @@ -911,12 +926,14 @@ fn rewrite_self_in_tokens(tokens: TokenStream2) -> TokenStream2 { .collect() } -fn refine_fn_name(func: &FnItemWithSignature, path: &[usize]) -> syn::Ident { - let pos = path - .iter() - .map(|i| i.to_string()) - .collect::>() - .join("_"); +fn refine_fn_name(func: &FnItemWithSignature, r: &Refinement) -> syn::Ident { + let mut pos = match r.root { + RefineRoot::Param(i) => format!("p{}", i), + RefineRoot::Return => "ret".to_string(), + }; + for p in &r.projection { + pos.push_str(&format!("_{}", p)); + } format_ident!("_thrust_refine_{}_{}", func.sig().ident, pos) } @@ -925,7 +942,7 @@ fn refine_formula_fn( outer_context: Option<&FnOuterItem>, r: &Refinement, ) -> TokenStream2 { - let name = refine_fn_name(func, &r.path); + let name = refine_fn_name(func, r); let def_generics = generic_params_tokens(&func.sig().generics); let model_params = fn_params_with_model_ty(&func.sig().inputs); let model_preds = model_where_predicates(func, outer_context); @@ -948,20 +965,26 @@ fn refine_formula_fn( } fn refine_path_stmt(func: &FnItemWithSignature, r: &Refinement) -> TokenStream2 { - let name = refine_fn_name(func, &r.path); + let name = refine_fn_name(func, r); let turbofish = generic_turbofish(&func.sig().generics); let path_prefix = if func.sig().receiver().is_some() { quote!(Self::) } else { quote!() }; - let pos = r - .path + let root = match r.root { + RefineRoot::Param(i) => { + let lit = proc_macro2::Literal::usize_unsuffixed(i); + quote!(#lit) + } + RefineRoot::Return => quote!(result), + }; + let projection = r + .projection .iter() - .map(|i| proc_macro2::Literal::usize_unsuffixed(*i)) - .collect::>(); + .map(|i| proc_macro2::Literal::usize_unsuffixed(*i)); quote! { - #[thrust::refine(#(#pos),*)] + #[thrust::refine(#root #(, #projection)*)] #path_prefix #name #turbofish; } } From 1275205a335c6e5dd51ef09cf8eea0d967fc1621 Mon Sep 17 00:00:00 2001 From: Claude Date: Wed, 27 May 2026 01:56:35 +0000 Subject: [PATCH 03/18] Redesign TypePosition as a flat sequence of TypePositionStep Replace the previous TypePositionRoot + projection: Vec split with a flat Vec, where each step is one of: - Param(FunctionParamIdx): navigate into a function type's parameter - Return: navigate into a function type's return slot - TypeArg(usize): navigate into a generic type's type argument This makes the path representation uniform across all type levels, enabling future support for refinements on positions inside higher-order function types (e.g. [$0, result] for the return type of a function-typed first parameter). The attribute encoding changes accordingly: - result (ident) => Return - integer i => Param(i) - bracket group [i] => TypeArg(i) The macro-side RefineRoot+projection split is similarly replaced with a flat Vec in the Refinement struct, and scan_type now threads the full steps Vec through recursive calls. https://claude.ai/code/session_01Km9xwaVaGQjnAy1iHsPewa --- src/analyze/annot.rs | 62 ++++++++++++------- src/refine/template.rs | 32 ++++++---- src/rty.rs | 129 +++++++++++++++++++++++++++------------ thrust-macros/src/lib.rs | 101 +++++++++++++++++------------- 4 files changed, 210 insertions(+), 114 deletions(-) diff --git a/src/analyze/annot.rs b/src/analyze/annot.rs index 96854d32..c38fa248 100644 --- a/src/analyze/annot.rs +++ b/src/analyze/annot.rs @@ -222,12 +222,19 @@ pub fn extract_annot_tokens(attr: Attribute) -> TokenStream { /// Parses a [`rty::TypePosition`] from the tokens of a `#[thrust::refine(..)]` /// attribute. /// -/// The first token is the root: the keyword `result` for the return, or an -/// integer for a parameter index. The remaining comma-separated integers form -/// the projection into nested type arguments. For example `result, 0` is the -/// first type-argument of the return, and `1` is the second parameter. +/// Tokens are comma-separated steps. Each step is one of: +/// - The keyword `result` → [`rty::TypePositionStep::Return`] (navigate to a +/// function type's return slot). +/// - An integer literal `i` → [`rty::TypePositionStep::Param`]`(i)` (navigate +/// to the `i`-th parameter of a function type). +/// - A bracket group `[i]` → [`rty::TypePositionStep::TypeArg`]`(i)` (navigate +/// to the `i`-th type argument of a generic type such as an enum or `Box`). +/// +/// Examples: `result` is the return; `0` is the first parameter; `0, [0]` is +/// the first type-argument of the first parameter; `0, result` is the return of +/// a function-typed first parameter. pub fn parse_type_position(ts: &TokenStream) -> rty::TypePosition { - use rustc_ast::token::{LitKind, TokenKind}; + use rustc_ast::token::{Delimiter, LitKind, TokenKind}; use rustc_ast::tokenstream::TokenTree; let parse_int = |lit: &rustc_ast::token::Lit| -> usize { @@ -242,30 +249,43 @@ pub fn parse_type_position(ts: &TokenStream) -> rty::TypePosition { .expect("invalid integer in refine position") }; - let mut iter = ts.iter(); - let root = match iter.next() { - Some(TokenTree::Token(t, _)) => match &t.kind { - TokenKind::Ident(sym, _) if sym.as_str() == "result" => rty::TypePositionRoot::Return, - TokenKind::Literal(lit) => { - rty::TypePositionRoot::Param(rty::FunctionParamIdx::from(parse_int(lit))) - } - _ => panic!("unexpected refine position root: {:?}", t), - }, - _ => panic!("empty refine position"), - }; - - let mut projection = Vec::new(); - for tt in iter { + let mut steps = Vec::new(); + for tt in ts.iter() { match tt { TokenTree::Token(t, _) => match &t.kind { TokenKind::Comma => {} - TokenKind::Literal(lit) => projection.push(parse_int(lit)), + TokenKind::Ident(sym, _) if sym.as_str() == "result" => { + steps.push(rty::TypePositionStep::Return); + } + TokenKind::Literal(lit) => { + steps.push(rty::TypePositionStep::Param(rty::FunctionParamIdx::from( + parse_int(lit), + ))); + } _ => panic!("unexpected token in refine position: {:?}", t), }, + TokenTree::Delimited(_, _, Delimiter::Bracket, inner) => { + let mut inner_iter = inner.iter(); + let i = match inner_iter.next() { + Some(TokenTree::Token(t, _)) => match &t.kind { + TokenKind::Literal(lit) => parse_int(lit), + _ => panic!("expected integer inside [..] refine step: {:?}", t), + }, + _ => panic!("expected integer inside [..] refine step"), + }; + assert!( + inner_iter.next().is_none(), + "expected exactly one integer inside [..] refine step" + ); + steps.push(rty::TypePositionStep::TypeArg(i)); + } _ => panic!("unexpected token tree in refine position"), } } - rty::TypePosition::new(root, projection) + + assert!(!steps.is_empty(), "empty refine position"); + let first = steps.remove(0); + rty::TypePosition::new(first, steps) } pub fn split_param(ts: &TokenStream) -> (Ident, TokenStream) { diff --git a/src/refine/template.rs b/src/refine/template.rs index 04fb8eae..50408612 100644 --- a/src/refine/template.rs +++ b/src/refine/template.rs @@ -566,26 +566,33 @@ impl<'tcx, 'a, R> FunctionTemplateTypeBuilder<'tcx, 'a, R> { self } - /// Installs a refinement at a [`rty::TypePosition`]. The root selects a - /// parameter or the return slot; the projection then descends into the - /// slot's nested type arguments. + /// Installs a refinement at a [`rty::TypePosition`]. + /// + /// The first step must be [`rty::TypePositionStep::Param`] or + /// [`rty::TypePositionStep::Return`]; the remaining steps are forwarded to + /// [`rty::RefinedType::install_refinement_at`]. pub fn refine( &mut self, position: &rty::TypePosition, refinement: rty::Refinement, ) -> &mut Self { - match position.root { - rty::TypePositionRoot::Param(idx) => { - if !self.param_rtys.contains_key(&idx) { + let (first, rest) = match position.steps().split_first() { + Some(pair) => pair, + None => panic!("empty TypePosition"), + }; + match first { + rty::TypePositionStep::Param(idx) => { + if !self.param_rtys.contains_key(idx) { let ty = self.inner.build(self.param_tys[idx.index()].ty).vacuous(); - self.param_rtys.insert(idx, rty::RefinedType::unrefined(ty)); + self.param_rtys + .insert(*idx, rty::RefinedType::unrefined(ty)); } self.param_rtys - .get_mut(&idx) + .get_mut(idx) .unwrap() - .install_refinement_at(&position.projection, refinement); + .install_refinement_at(rest, refinement); } - rty::TypePositionRoot::Return => { + rty::TypePositionStep::Return => { if self.ret_rty.is_none() { let ty = self.inner.build(self.ret_ty).vacuous(); self.ret_rty = Some(rty::RefinedType::unrefined(ty)); @@ -593,7 +600,10 @@ impl<'tcx, 'a, R> FunctionTemplateTypeBuilder<'tcx, 'a, R> { self.ret_rty .as_mut() .unwrap() - .install_refinement_at(&position.projection, refinement); + .install_refinement_at(rest, refinement); + } + rty::TypePositionStep::TypeArg(_) => { + panic!("TypePosition must start with Param or Return, not TypeArg"); } } self diff --git a/src/rty.rs b/src/rty.rs index e80ef866..b0f10f3d 100644 --- a/src/rty.rs +++ b/src/rty.rs @@ -83,48 +83,81 @@ where } } -/// Selects a parameter or the return of a function type — the root of a -/// [`TypePosition`]. +/// One step in a [`TypePosition`] path. +/// +/// A path is a sequence of steps that addresses a sub-type within a +/// (potentially nested) function signature: +/// - [`Param`](Self::Param) / [`Return`](Self::Return) navigate into a +/// function type's parameter or return slot respectively. +/// - [`TypeArg`](Self::TypeArg) navigates into the `i`-th type argument of a +/// generic type (enum, `Box`, etc.). +/// +/// Using distinct variants for function navigation ([`Param`](Self::Param), +/// [`Return`](Self::Return)) and generic-arg navigation +/// ([`TypeArg`](Self::TypeArg)) allows the same path representation to address +/// positions inside higher-order function types. For example, `[$0, result]` +/// addresses the return type of a function-typed first parameter. #[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum TypePositionRoot { +pub enum TypePositionStep { + /// Navigate to the `i`-th parameter of a function type. Param(FunctionParamIdx), + /// Navigate to the return type of a function type. Return, + /// Navigate to the `i`-th type argument of a generic type (enum, `Box`, …). + TypeArg(usize), } -impl std::fmt::Display for TypePositionRoot { +impl std::fmt::Display for TypePositionStep { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self { - TypePositionRoot::Param(idx) => write!(f, "{}", idx), - TypePositionRoot::Return => f.write_str("result"), + TypePositionStep::Param(idx) => write!(f, "{}", idx), + TypePositionStep::Return => f.write_str("result"), + TypePositionStep::TypeArg(i) => write!(f, "[{}]", i), } } } -/// A position addressing a sub-type within a function type, used to attach a -/// refinement. +/// A path addressing a sub-type in a function's type signature, used to attach +/// a refinement. /// -/// The [`root`](Self::root) selects a parameter or the return; the -/// [`projection`](Self::projection) then descends into nested type arguments -/// (enum type-arguments, `Box` pointee). For example, `result.0` addresses the -/// first type-argument of the return type, and `$1` addresses the second -/// parameter. +/// The first step must be [`TypePositionStep::Param`] or +/// [`TypePositionStep::Return`] (selecting which slot of the top-level function +/// type to enter). Subsequent steps can freely combine +/// [`TypePositionStep::Param`] / [`TypePositionStep::Return`] (for +/// function-typed positions) and [`TypePositionStep::TypeArg`] (for generic +/// types), enabling positions inside higher-order function types. +/// +/// Examples (function `fn f(x: List) -> Box`): +/// - `[$0]` — parameter `x`. +/// - `[result]` — the return type. +/// - `[$0, [0]]` — the first type arg of `x`. +/// - `[result, [0]]` — the pointee of the `Box` return. +/// - `[$0, result]` — the return of a function-typed param `x`. #[derive(Debug, Clone, PartialEq, Eq)] pub struct TypePosition { - pub root: TypePositionRoot, - pub projection: Vec, + steps: Vec, } impl TypePosition { - pub fn new(root: TypePositionRoot, projection: Vec) -> Self { - TypePosition { root, projection } + pub fn new(first: TypePositionStep, rest: Vec) -> Self { + let mut steps = vec![first]; + steps.extend(rest); + TypePosition { steps } + } + + pub fn steps(&self) -> &[TypePositionStep] { + &self.steps } } impl std::fmt::Display for TypePosition { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(f, "{}", self.root)?; - for p in &self.projection { - write!(f, ".{}", p)?; + let mut iter = self.steps.iter(); + if let Some(first) = iter.next() { + write!(f, "{}", first)?; + } + for s in iter { + write!(f, ".{}", s)?; } Ok(()) } @@ -1523,33 +1556,49 @@ where } impl RefinedType { - /// Installs `refinement` at the given projection — a path of nested - /// type-argument indices descending through enum type arguments and `Box` - /// pointees. An empty projection replaces the refinement at this node. + /// Installs `refinement` at the sub-type addressed by `steps`. + /// + /// An empty `steps` slice replaces the refinement at this node. Each step + /// in the slice navigates one level deeper: + /// - [`TypePositionStep::TypeArg`] descends into enum type arguments or the + /// `Box` pointee. + /// - [`TypePositionStep::Param`] / [`TypePositionStep::Return`] descend + /// into a function-typed position's parameter or return slot. pub fn install_refinement_at( &mut self, - projection: &[usize], + steps: &[TypePositionStep], refinement: Refinement, ) { - let Some((&step, rest)) = projection.split_first() else { + let Some((step, rest)) = steps.split_first() else { self.refinement = refinement; return; }; - match &mut self.ty { - Type::Enum(e) => { - let arg = e.args.get_mut(TypeParamIdx::from(step)).unwrap_or_else(|| { - panic!("refine projection {} out of range for enum type", step) - }); - arg.install_refinement_at(rest, refinement); - } - Type::Pointer(p) => { - assert_eq!(step, 0, "Box type position must be 0"); - p.elem.install_refinement_at(rest, refinement); - } - ty => panic!( - "unsupported type at refine projection step {}: {:?}", - step, ty - ), + match step { + TypePositionStep::TypeArg(i) => match &mut self.ty { + Type::Enum(e) => { + let arg = e.args.get_mut(TypeParamIdx::from(*i)).unwrap_or_else(|| { + panic!("refine step [{}] out of range for enum type", i) + }); + arg.install_refinement_at(rest, refinement); + } + Type::Pointer(p) => { + assert_eq!(*i, 0, "Box type position must be [0]"); + p.elem.install_refinement_at(rest, refinement); + } + ty => panic!("TypeArg step on unsupported type: {:?}", ty), + }, + TypePositionStep::Param(idx) => match &mut self.ty { + Type::Function(func) => { + func.params[*idx].install_refinement_at(rest, refinement); + } + ty => panic!("Param step on non-function type: {:?}", ty), + }, + TypePositionStep::Return => match &mut self.ty { + Type::Function(func) => { + func.ret.install_refinement_at(rest, refinement); + } + ty => panic!("Return step on non-function type: {:?}", ty), + }, } } } diff --git a/thrust-macros/src/lib.rs b/thrust-macros/src/lib.rs index c9ddb2d1..2f137dfd 100644 --- a/thrust-macros/src/lib.rs +++ b/thrust-macros/src/lib.rs @@ -583,17 +583,25 @@ impl ExpandedTokens { // `#[thrust::refine(result, 0)]` is the first type-argument of the return. // --------------------------------------------------------------------------- -/// Root of a refinement's type position: a parameter (by index) or the return. +/// One step in a refinement's type-position path. +/// +/// Mirrors [`rty::TypePositionStep`] on the plugin side and uses the same +/// attribute encoding: +/// - [`Param`](Self::Param) / [`Return`](Self::Return) navigate into a function +/// type; encoded as an integer literal / the `result` keyword. +/// - [`TypeArg`](Self::TypeArg) navigates into a generic type argument; encoded +/// as a bracket group `[i]`. #[derive(Clone, Copy)] -enum RefineRoot { +enum RefineStep { Param(usize), Return, + TypeArg(usize), } #[derive(Clone)] struct Refinement { - root: RefineRoot, - projection: Vec, + /// Full type-position path from the function root to the refined type. + steps: Vec, binder: syn::Ident, binder_ty: TokenStream2, formula: TokenStream2, @@ -641,8 +649,8 @@ fn expand_refine(kind: RefineKind, attr: TokenStream, item: TokenStream) -> Toke }; let mut refinements = Vec::new(); - for (root, ty_tokens) in jobs { - if let Err(e) = scan_type(&ty_tokens, root, &[], &mut refinements) { + for (root_steps, ty_tokens) in jobs { + if let Err(e) = scan_type(&ty_tokens, root_steps, &mut refinements) { let err = e.to_compile_error(); return quote! { #err #func }.into(); } @@ -686,27 +694,32 @@ fn expand_refine(kind: RefineKind, attr: TokenStream, item: TokenStream) -> Toke .into() } -/// Builds `(root, type_tokens)` jobs to scan from the attribute tokens. +/// Builds `(root_steps, type_tokens)` jobs to scan from the attribute tokens. +/// +/// Each job's `root_steps` contains the initial [`RefineStep`]s that fix the +/// position of the type expression within the function signature (e.g. +/// `[Param(0)]` for the first parameter). [`scan_type`] will append further +/// [`RefineStep::TypeArg`] steps as it descends into generic type arguments. fn build_refine_jobs( kind: RefineKind, func: &FnItemWithSignature, attr_tokens: &[TokenTree2], -) -> syn::Result)>> { +) -> syn::Result, Vec)>> { match kind { RefineKind::Param => { let (name, ty_tokens) = split_name_type(attr_tokens)?; let idx = param_index(func, &name)?; - Ok(vec![(RefineRoot::Param(idx), ty_tokens)]) + Ok(vec![(vec![RefineStep::Param(idx)], ty_tokens)]) } - RefineKind::Ret => Ok(vec![(RefineRoot::Return, attr_tokens.to_vec())]), + RefineKind::Ret => Ok(vec![(vec![RefineStep::Return], attr_tokens.to_vec())]), RefineKind::Sig => { let (args, ret_tokens) = parse_sig_attr(attr_tokens)?; let mut jobs = Vec::new(); for (name, ty_tokens) in args { let idx = param_index(func, &name)?; - jobs.push((RefineRoot::Param(idx), ty_tokens)); + jobs.push((vec![RefineStep::Param(idx)], ty_tokens)); } - jobs.push((RefineRoot::Return, ret_tokens)); + jobs.push((vec![RefineStep::Return], ret_tokens)); Ok(jobs) } } @@ -776,13 +789,16 @@ fn parse_sig_attr( Ok((args, rest.to_vec())) } -/// Scans a single type expression, recording every refinement node together -/// with its type position (a fixed `root` plus the `projection` accumulated -/// while descending into nested type arguments). +/// Scans a type expression and records every refinement node with its full +/// type-position path (`steps`). +/// +/// `steps` holds the path from the function root to the current type node. +/// When a refinement `{binder: ty | formula}` is found the current `steps` are +/// recorded; when descending into generic type arguments a +/// [`RefineStep::TypeArg`]`(i)` step is appended to `steps`. fn scan_type( tokens: &[TokenTree2], - root: RefineRoot, - projection: &[usize], + steps: Vec, out: &mut Vec, ) -> syn::Result<()> { if tokens.is_empty() { @@ -795,15 +811,13 @@ fn scan_type( if g.delimiter() == proc_macro2::Delimiter::Brace { let (binder, binder_ty, formula) = split_refinement(g.stream())?; out.push(Refinement { - root, - projection: projection.to_vec(), + steps: steps.clone(), binder, binder_ty: binder_ty.iter().cloned().collect(), formula, }); - // The refinement's own type sits at the same position; descend - // into it to find further nested refinements. - scan_type(&binder_ty, root, projection, out)?; + // Descend into the binder type for nested refinements. + scan_type(&binder_ty, steps, out)?; return Ok(()); } } @@ -818,9 +832,9 @@ fn scan_type( if is_lifetime(&arg) { continue; } - let mut child = projection.to_vec(); - child.push(type_idx); - scan_type(&arg, root, &child, out)?; + let mut child = steps.clone(); + child.push(RefineStep::TypeArg(type_idx)); + scan_type(&arg, child, out)?; type_idx += 1; } } @@ -927,13 +941,16 @@ fn rewrite_self_in_tokens(tokens: TokenStream2) -> TokenStream2 { } fn refine_fn_name(func: &FnItemWithSignature, r: &Refinement) -> syn::Ident { - let mut pos = match r.root { - RefineRoot::Param(i) => format!("p{}", i), - RefineRoot::Return => "ret".to_string(), - }; - for p in &r.projection { - pos.push_str(&format!("_{}", p)); - } + let pos = r + .steps + .iter() + .map(|s| match s { + RefineStep::Param(i) => format!("p{}", i), + RefineStep::Return => "ret".to_string(), + RefineStep::TypeArg(i) => format!("t{}", i), + }) + .collect::>() + .join("_"); format_ident!("_thrust_refine_{}_{}", func.sig().ident, pos) } @@ -972,19 +989,19 @@ fn refine_path_stmt(func: &FnItemWithSignature, r: &Refinement) -> TokenStream2 } else { quote!() }; - let root = match r.root { - RefineRoot::Param(i) => { - let lit = proc_macro2::Literal::usize_unsuffixed(i); + let encoded_steps = r.steps.iter().map(|s| match s { + RefineStep::Param(i) => { + let lit = proc_macro2::Literal::usize_unsuffixed(*i); quote!(#lit) } - RefineRoot::Return => quote!(result), - }; - let projection = r - .projection - .iter() - .map(|i| proc_macro2::Literal::usize_unsuffixed(*i)); + RefineStep::Return => quote!(result), + RefineStep::TypeArg(i) => { + let lit = proc_macro2::Literal::usize_unsuffixed(*i); + quote!([#lit]) + } + }); quote! { - #[thrust::refine(#root #(, #projection)*)] + #[thrust::refine(#(#encoded_steps),*)] #path_prefix #name #turbofish; } } From 0e9f0e33782ac3785b5f5d8dc666e2e097614b28 Mon Sep 17 00:00:00 2001 From: Claude Date: Wed, 27 May 2026 08:49:56 +0000 Subject: [PATCH 04/18] Rename thrust::refine attribute to thrust::refinement_path Align the macro-emitted attribute with the existing requires_path / ensures_path convention: the _path suffix conveys that the attribute's target is a path to a formula_fn. The symbol-path helper follows suit (refine_path -> refinement_path_path). https://claude.ai/code/session_01Km9xwaVaGQjnAy1iHsPewa --- src/analyze.rs | 10 +++++----- src/analyze/annot.rs | 8 ++++---- thrust-macros/src/lib.rs | 15 ++++++++------- 3 files changed, 17 insertions(+), 16 deletions(-) diff --git a/src/analyze.rs b/src/analyze.rs index 2111c39f..e16fb08b 100644 --- a/src/analyze.rs +++ b/src/analyze.rs @@ -790,8 +790,8 @@ impl<'tcx> Analyzer<'tcx> { ensure_annot } - /// Collects every `#[thrust::refine(..)]` path statement in the function - /// body, returning each `(type position, formula_fn DefId)`. + /// Collects every `#[thrust::refinement_path(..)]` path statement in the + /// function body, returning each `(type position, formula_fn DefId)`. fn extract_refine_paths(&self, local_def_id: LocalDefId) -> Vec<(rty::TypePosition, DefId)> { let mut out = Vec::new(); let Some(body) = self.tcx.hir_maybe_body_owned_by(local_def_id) else { @@ -800,7 +800,7 @@ impl<'tcx> Analyzer<'tcx> { let rustc_hir::ExprKind::Block(block, _) = body.value.kind else { return out; }; - let attr_path = analyze::annot::refine_path(); + let attr_path = analyze::annot::refinement_path_path(); let typeck = self.tcx.typeck(local_def_id); for stmt in block.stmts { let Some(attr) = self @@ -840,8 +840,8 @@ impl<'tcx> Analyzer<'tcx> { out } - /// Resolves every `#[thrust::refine(..)]` annotation into a positioned - /// refinement, by translating the referenced formula function. + /// Resolves every `#[thrust::refinement_path(..)]` annotation into a + /// positioned refinement, by translating the referenced formula function. pub fn extract_refine_annots( &self, local_def_id: LocalDefId, diff --git a/src/analyze/annot.rs b/src/analyze/annot.rs index c38fa248..5ea243c0 100644 --- a/src/analyze/annot.rs +++ b/src/analyze/annot.rs @@ -61,8 +61,8 @@ pub fn ensures_path_path() -> [Symbol; 2] { [Symbol::intern("thrust"), Symbol::intern("ensures_path")] } -pub fn refine_path() -> [Symbol; 2] { - [Symbol::intern("thrust"), Symbol::intern("refine")] +pub fn refinement_path_path() -> [Symbol; 2] { + [Symbol::intern("thrust"), Symbol::intern("refinement_path")] } pub fn model_ty_path() -> [Symbol; 3] { @@ -219,8 +219,8 @@ pub fn extract_annot_tokens(attr: Attribute) -> TokenStream { d.tokens } -/// Parses a [`rty::TypePosition`] from the tokens of a `#[thrust::refine(..)]` -/// attribute. +/// Parses a [`rty::TypePosition`] from the tokens of a +/// `#[thrust::refinement_path(..)]` attribute. /// /// Tokens are comma-separated steps. Each step is one of: /// - The keyword `result` → [`rty::TypePositionStep::Return`] (navigate to a diff --git a/thrust-macros/src/lib.rs b/thrust-macros/src/lib.rs index 2f137dfd..ea27b3c7 100644 --- a/thrust-macros/src/lib.rs +++ b/thrust-macros/src/lib.rs @@ -575,12 +575,13 @@ impl ExpandedTokens { // Refinement-type annotations: `param`, `ret`, `sig`. // // These lower refinement types (e.g. `List<{ v: i32 | v > 0 }>`) into -// `#[thrust::formula_fn]`s plus positioned `#[thrust::refine(..)]` path -// statements injected into the function body. The "type position" addresses -// into the function type: its root selects a parameter (by index) or the -// return (the `result` keyword), and the projection (the remaining indices) -// descends into generic arguments (enum args / `Box` pointee). For example, -// `#[thrust::refine(result, 0)]` is the first type-argument of the return. +// `#[thrust::formula_fn]`s plus positioned `#[thrust::refinement_path(..)]` +// path statements injected into the function body. The "type position" +// addresses into the function type: a parameter (by index) or the return (the +// `result` keyword) selects a function slot, and bracket steps (`[i]`) descend +// into generic arguments (enum args / `Box` pointee). For example, +// `#[thrust::refinement_path(result, [0])]` is the first type-argument of the +// return. // --------------------------------------------------------------------------- /// One step in a refinement's type-position path. @@ -1001,7 +1002,7 @@ fn refine_path_stmt(func: &FnItemWithSignature, r: &Refinement) -> TokenStream2 } }); quote! { - #[thrust::refine(#(#encoded_steps),*)] + #[thrust::refinement_path(#(#encoded_steps),*)] #path_prefix #name #turbofish; } } From bd081744f000605f96c6622a224882d3b9dbebfe Mon Sep 17 00:00:00 2001 From: Claude Date: Wed, 27 May 2026 12:34:27 +0000 Subject: [PATCH 05/18] Extract refinement-type annotation macros into refine module Move the param/ret/sig expansion logic and its token-scanning helpers out of the crate root into a dedicated refine module, leaving only the thin proc-macro entry points in lib.rs. Shrinks lib.rs from ~1232 to ~813 lines. https://claude.ai/code/session_01Km9xwaVaGQjnAy1iHsPewa --- thrust-macros/src/lib.rs | 430 +----------------------------------- thrust-macros/src/refine.rs | 428 +++++++++++++++++++++++++++++++++++ 2 files changed, 433 insertions(+), 425 deletions(-) create mode 100644 thrust-macros/src/refine.rs diff --git a/thrust-macros/src/lib.rs b/thrust-macros/src/lib.rs index ea27b3c7..debb22aa 100644 --- a/thrust-macros/src/lib.rs +++ b/thrust-macros/src/lib.rs @@ -1,5 +1,5 @@ use proc_macro::TokenStream; -use proc_macro2::{TokenStream as TokenStream2, TokenTree as TokenTree2}; +use proc_macro2::TokenStream as TokenStream2; use quote::{format_ident, quote, ToTokens}; use syn::{ parse_macro_input, punctuated::Punctuated, FnArg, GenericParam, Generics, TypeParamBound, @@ -7,6 +7,7 @@ use syn::{ }; mod invariant; +mod refine; #[derive(Debug, Clone)] enum FnOuterItem { @@ -571,440 +572,19 @@ impl ExpandedTokens { } } -// --------------------------------------------------------------------------- -// Refinement-type annotations: `param`, `ret`, `sig`. -// -// These lower refinement types (e.g. `List<{ v: i32 | v > 0 }>`) into -// `#[thrust::formula_fn]`s plus positioned `#[thrust::refinement_path(..)]` -// path statements injected into the function body. The "type position" -// addresses into the function type: a parameter (by index) or the return (the -// `result` keyword) selects a function slot, and bracket steps (`[i]`) descend -// into generic arguments (enum args / `Box` pointee). For example, -// `#[thrust::refinement_path(result, [0])]` is the first type-argument of the -// return. -// --------------------------------------------------------------------------- - -/// One step in a refinement's type-position path. -/// -/// Mirrors [`rty::TypePositionStep`] on the plugin side and uses the same -/// attribute encoding: -/// - [`Param`](Self::Param) / [`Return`](Self::Return) navigate into a function -/// type; encoded as an integer literal / the `result` keyword. -/// - [`TypeArg`](Self::TypeArg) navigates into a generic type argument; encoded -/// as a bracket group `[i]`. -#[derive(Clone, Copy)] -enum RefineStep { - Param(usize), - Return, - TypeArg(usize), -} - -#[derive(Clone)] -struct Refinement { - /// Full type-position path from the function root to the refined type. - steps: Vec, - binder: syn::Ident, - binder_ty: TokenStream2, - formula: TokenStream2, -} - #[proc_macro_attribute] pub fn param(attr: TokenStream, item: TokenStream) -> TokenStream { - expand_refine(RefineKind::Param, attr, item) + refine::expand_refine(refine::RefineKind::Param, attr, item) } #[proc_macro_attribute] pub fn ret(attr: TokenStream, item: TokenStream) -> TokenStream { - expand_refine(RefineKind::Ret, attr, item) + refine::expand_refine(refine::RefineKind::Ret, attr, item) } #[proc_macro_attribute] pub fn sig(attr: TokenStream, item: TokenStream) -> TokenStream { - expand_refine(RefineKind::Sig, attr, item) -} - -enum RefineKind { - Param, - Ret, - Sig, -} - -fn expand_refine(kind: RefineKind, attr: TokenStream, item: TokenStream) -> TokenStream { - let mut func = parse_macro_input!(item as FnItemWithSignature); - - let outer_context = match extract_outer_context(&func) { - Ok(ctx) => ctx, - Err(e) => { - let err = e.to_compile_error(); - return quote! { #err #func }.into(); - } - }; - - let attr_tokens: Vec = TokenStream2::from(attr).into_iter().collect(); - let jobs = match build_refine_jobs(kind, &func, &attr_tokens) { - Ok(jobs) => jobs, - Err(e) => { - let err = e.to_compile_error(); - return quote! { #err #func }.into(); - } - }; - - let mut refinements = Vec::new(); - for (root_steps, ty_tokens) in jobs { - if let Err(e) = scan_type(&ty_tokens, root_steps, &mut refinements) { - let err = e.to_compile_error(); - return quote! { #err #func }.into(); - } - } - - if refinements.is_empty() { - return func.into_token_stream().into(); - } - - let has_receiver = func.sig().receiver().is_some(); - let mut formula_fns = Vec::new(); - let mut path_stmts = Vec::new(); - for mut r in refinements { - if has_receiver { - r.formula = rewrite_self_in_tokens(r.formula); - } - formula_fns.push(refine_formula_fn(&func, outer_context.as_ref(), &r)); - path_stmts.push(refine_path_stmt(&func, &r)); - } - - let Some(block) = func.block_mut() else { - let err = syn::Error::new_spanned( - func.sig().ident.clone(), - "refinement-type annotations require a function body", - ) - .into_compile_error(); - return quote! { #err #func }.into(); - }; - let orig_stmts = block.stmts.drain(..).collect::>(); - *block = syn::parse_quote!({ - #(#path_stmts)* - #(#orig_stmts)* - }); - func.attrs_mut() - .push(syn::parse_quote!(#[allow(path_statements)])); - - quote! { - #(#formula_fns)* - #func - } - .into() -} - -/// Builds `(root_steps, type_tokens)` jobs to scan from the attribute tokens. -/// -/// Each job's `root_steps` contains the initial [`RefineStep`]s that fix the -/// position of the type expression within the function signature (e.g. -/// `[Param(0)]` for the first parameter). [`scan_type`] will append further -/// [`RefineStep::TypeArg`] steps as it descends into generic type arguments. -fn build_refine_jobs( - kind: RefineKind, - func: &FnItemWithSignature, - attr_tokens: &[TokenTree2], -) -> syn::Result, Vec)>> { - match kind { - RefineKind::Param => { - let (name, ty_tokens) = split_name_type(attr_tokens)?; - let idx = param_index(func, &name)?; - Ok(vec![(vec![RefineStep::Param(idx)], ty_tokens)]) - } - RefineKind::Ret => Ok(vec![(vec![RefineStep::Return], attr_tokens.to_vec())]), - RefineKind::Sig => { - let (args, ret_tokens) = parse_sig_attr(attr_tokens)?; - let mut jobs = Vec::new(); - for (name, ty_tokens) in args { - let idx = param_index(func, &name)?; - jobs.push((vec![RefineStep::Param(idx)], ty_tokens)); - } - jobs.push((vec![RefineStep::Return], ret_tokens)); - Ok(jobs) - } - } -} - -fn param_index(func: &FnItemWithSignature, name: &syn::Ident) -> syn::Result { - let pos = func.sig().inputs.iter().position(|arg| match arg { - FnArg::Receiver(_) => name == "self", - FnArg::Typed(pt) => matches!(&*pt.pat, syn::Pat::Ident(pi) if &pi.ident == name), - }); - pos.ok_or_else(|| { - syn::Error::new_spanned(name, format!("no parameter named `{}` in signature", name)) - }) -} - -/// Parses `name : ` from a flat token slice. -fn split_name_type(tokens: &[TokenTree2]) -> syn::Result<(syn::Ident, Vec)> { - let name = match tokens.first() { - Some(TokenTree2::Ident(id)) => id.clone(), - _ => return Err(err_tokens(tokens, "expected a parameter name")), - }; - match tokens.get(1) { - Some(TokenTree2::Punct(p)) if p.as_char() == ':' => {} - _ => return Err(err_tokens(tokens, "expected `:` after parameter name")), - } - Ok((name, tokens[2..].to_vec())) -} - -/// Parses `fn ( n0: t0 , ... ) -> ret` into `((name, ty_tokens)*, ret_tokens)`. -#[allow(clippy::type_complexity)] -fn parse_sig_attr( - tokens: &[TokenTree2], -) -> syn::Result<(Vec<(syn::Ident, Vec)>, Vec)> { - match tokens.first() { - Some(TokenTree2::Ident(id)) if id == "fn" => {} - _ => return Err(err_tokens(tokens, "expected `fn` in sig annotation")), - } - let arg_group = match tokens.get(1) { - Some(TokenTree2::Group(g)) if g.delimiter() == proc_macro2::Delimiter::Parenthesis => g, - _ => return Err(err_tokens(tokens, "expected `(..)` after `fn`")), - }; - - let mut args = Vec::new(); - let arg_tokens: Vec = arg_group.stream().into_iter().collect(); - for arg in split_top_level_commas(&arg_tokens) { - if arg.is_empty() { - continue; - } - args.push(split_name_type(&arg)?); - } - - // expect `->` then the return type - let mut rest = &tokens[2..]; - match (rest.first(), rest.get(1)) { - (Some(TokenTree2::Punct(a)), Some(TokenTree2::Punct(b))) - if a.as_char() == '-' && b.as_char() == '>' => - { - rest = &rest[2..]; - } - _ => { - return Err(err_tokens( - tokens, - "expected `->` and a return type in sig annotation", - )) - } - } - Ok((args, rest.to_vec())) -} - -/// Scans a type expression and records every refinement node with its full -/// type-position path (`steps`). -/// -/// `steps` holds the path from the function root to the current type node. -/// When a refinement `{binder: ty | formula}` is found the current `steps` are -/// recorded; when descending into generic type arguments a -/// [`RefineStep::TypeArg`]`(i)` step is appended to `steps`. -fn scan_type( - tokens: &[TokenTree2], - steps: Vec, - out: &mut Vec, -) -> syn::Result<()> { - if tokens.is_empty() { - return Ok(()); - } - - // A refinement node is exactly a brace-delimited group. - if tokens.len() == 1 { - if let TokenTree2::Group(g) = &tokens[0] { - if g.delimiter() == proc_macro2::Delimiter::Brace { - let (binder, binder_ty, formula) = split_refinement(g.stream())?; - out.push(Refinement { - steps: steps.clone(), - binder, - binder_ty: binder_ty.iter().cloned().collect(), - formula, - }); - // Descend into the binder type for nested refinements. - scan_type(&binder_ty, steps, out)?; - return Ok(()); - } - } - } - - // A nominal type `Name` (`Box` included). - if let TokenTree2::Ident(_) = &tokens[0] { - if let Some(TokenTree2::Punct(p)) = tokens.get(1) { - if p.as_char() == '<' { - let mut type_idx = 0; - for arg in split_angle_args(&tokens[2..]) { - if is_lifetime(&arg) { - continue; - } - let mut child = steps.clone(); - child.push(RefineStep::TypeArg(type_idx)); - scan_type(&arg, child, out)?; - type_idx += 1; - } - } - } - } - - Ok(()) -} - -/// Splits `{ binder : ty | formula }` contents into its parts. -fn split_refinement( - stream: TokenStream2, -) -> syn::Result<(syn::Ident, Vec, TokenStream2)> { - let toks: Vec = stream.into_iter().collect(); - let bar = toks - .iter() - .position(|tt| matches!(tt, TokenTree2::Punct(p) if p.as_char() == '|')) - .ok_or_else(|| err_tokens(&toks, "refinement type must contain `|`"))?; - let (binder, binder_ty) = split_name_type(&toks[..bar])?; - let formula: TokenStream2 = toks[bar + 1..].iter().cloned().collect(); - Ok((binder, binder_ty, formula)) -} - -/// Splits the tokens following an opening `<` at top level by commas, stopping -/// at the matching `>`. -fn split_angle_args(tokens: &[TokenTree2]) -> Vec> { - let mut args = Vec::new(); - let mut cur = Vec::new(); - let mut depth = 1usize; - for tt in tokens { - if let TokenTree2::Punct(p) = tt { - match p.as_char() { - '<' => { - depth += 1; - cur.push(tt.clone()); - continue; - } - '>' => { - depth -= 1; - if depth == 0 { - break; - } - cur.push(tt.clone()); - continue; - } - ',' if depth == 1 => { - args.push(std::mem::take(&mut cur)); - continue; - } - _ => {} - } - } - cur.push(tt.clone()); - } - if !cur.is_empty() { - args.push(cur); - } - args -} - -fn split_top_level_commas(tokens: &[TokenTree2]) -> Vec> { - let mut out = Vec::new(); - let mut cur = Vec::new(); - let mut depth = 0i32; - for tt in tokens { - if let TokenTree2::Punct(p) = tt { - match p.as_char() { - '<' => depth += 1, - '>' => depth -= 1, - ',' if depth == 0 => { - out.push(std::mem::take(&mut cur)); - continue; - } - _ => {} - } - } - cur.push(tt.clone()); - } - out.push(cur); - out -} - -fn is_lifetime(tokens: &[TokenTree2]) -> bool { - matches!(tokens.first(), Some(TokenTree2::Punct(p)) if p.as_char() == '\'') -} - -fn err_tokens(tokens: &[TokenTree2], msg: &str) -> syn::Error { - let stream: TokenStream2 = tokens.iter().cloned().collect(); - syn::Error::new_spanned(stream, msg) -} - -fn rewrite_self_in_tokens(tokens: TokenStream2) -> TokenStream2 { - tokens - .into_iter() - .map(|tt| match tt { - TokenTree2::Ident(id) if id == "self" => TokenTree2::Ident(format_ident!("self_")), - TokenTree2::Group(g) => { - let inner = rewrite_self_in_tokens(g.stream()); - TokenTree2::Group(proc_macro2::Group::new(g.delimiter(), inner)) - } - other => other, - }) - .collect() -} - -fn refine_fn_name(func: &FnItemWithSignature, r: &Refinement) -> syn::Ident { - let pos = r - .steps - .iter() - .map(|s| match s { - RefineStep::Param(i) => format!("p{}", i), - RefineStep::Return => "ret".to_string(), - RefineStep::TypeArg(i) => format!("t{}", i), - }) - .collect::>() - .join("_"); - format_ident!("_thrust_refine_{}_{}", func.sig().ident, pos) -} - -fn refine_formula_fn( - func: &FnItemWithSignature, - outer_context: Option<&FnOuterItem>, - r: &Refinement, -) -> TokenStream2 { - let name = refine_fn_name(func, r); - let def_generics = generic_params_tokens(&func.sig().generics); - let model_params = fn_params_with_model_ty(&func.sig().inputs); - let model_preds = model_where_predicates(func, outer_context); - let extended_where = extended_where_clause(func, &model_preds); - let binder = &r.binder; - let binder_ty = &r.binder_ty; - let formula = &r.formula; - - quote! { - #[allow(unused_variables)] - #[allow(non_snake_case)] - #[thrust::formula_fn] - fn #name #def_generics( - #binder: <#binder_ty as thrust_models::Model>::Ty, - #model_params - ) -> bool #extended_where { - #formula - } - } -} - -fn refine_path_stmt(func: &FnItemWithSignature, r: &Refinement) -> TokenStream2 { - let name = refine_fn_name(func, r); - let turbofish = generic_turbofish(&func.sig().generics); - let path_prefix = if func.sig().receiver().is_some() { - quote!(Self::) - } else { - quote!() - }; - let encoded_steps = r.steps.iter().map(|s| match s { - RefineStep::Param(i) => { - let lit = proc_macro2::Literal::usize_unsuffixed(*i); - quote!(#lit) - } - RefineStep::Return => quote!(result), - RefineStep::TypeArg(i) => { - let lit = proc_macro2::Literal::usize_unsuffixed(*i); - quote!([#lit]) - } - }); - quote! { - #[thrust::refinement_path(#(#encoded_steps),*)] - #path_prefix #name #turbofish; - } + refine::expand_refine(refine::RefineKind::Sig, attr, item) } fn mentions_self(sig: &syn::Signature) -> bool { diff --git a/thrust-macros/src/refine.rs b/thrust-macros/src/refine.rs new file mode 100644 index 00000000..9199e660 --- /dev/null +++ b/thrust-macros/src/refine.rs @@ -0,0 +1,428 @@ +//! Refinement-type annotations: `param`, `ret`, `sig`. +//! +//! These lower refinement types (e.g. `List<{ v: i32 | v > 0 }>`) into +//! `#[thrust::formula_fn]`s plus positioned `#[thrust::refinement_path(..)]` +//! path statements injected into the function body. The "type position" +//! addresses into the function type: a parameter (by index) or the return (the +//! `result` keyword) selects a function slot, and bracket steps (`[i]`) descend +//! into generic arguments (enum args / `Box` pointee). For example, +//! `#[thrust::refinement_path(result, [0])]` is the first type-argument of the +//! return. + +use proc_macro::TokenStream; +use proc_macro2::{TokenStream as TokenStream2, TokenTree as TokenTree2}; +use quote::{format_ident, quote, ToTokens}; +use syn::{parse_macro_input, FnArg}; + +use super::{ + extended_where_clause, extract_outer_context, fn_params_with_model_ty, generic_params_tokens, + generic_turbofish, model_where_predicates, FnItemWithSignature, FnOuterItem, +}; + +/// One step in a refinement's type-position path. +/// +/// Mirrors [`rty::TypePositionStep`] on the plugin side and uses the same +/// attribute encoding: +/// - [`Param`](Self::Param) / [`Return`](Self::Return) navigate into a function +/// type; encoded as an integer literal / the `result` keyword. +/// - [`TypeArg`](Self::TypeArg) navigates into a generic type argument; encoded +/// as a bracket group `[i]`. +#[derive(Clone, Copy)] +enum RefineStep { + Param(usize), + Return, + TypeArg(usize), +} + +#[derive(Clone)] +struct Refinement { + /// Full type-position path from the function root to the refined type. + steps: Vec, + binder: syn::Ident, + binder_ty: TokenStream2, + formula: TokenStream2, +} + +pub(crate) enum RefineKind { + Param, + Ret, + Sig, +} + +pub(crate) fn expand_refine(kind: RefineKind, attr: TokenStream, item: TokenStream) -> TokenStream { + let mut func = parse_macro_input!(item as FnItemWithSignature); + + let outer_context = match extract_outer_context(&func) { + Ok(ctx) => ctx, + Err(e) => { + let err = e.to_compile_error(); + return quote! { #err #func }.into(); + } + }; + + let attr_tokens: Vec = TokenStream2::from(attr).into_iter().collect(); + let jobs = match build_refine_jobs(kind, &func, &attr_tokens) { + Ok(jobs) => jobs, + Err(e) => { + let err = e.to_compile_error(); + return quote! { #err #func }.into(); + } + }; + + let mut refinements = Vec::new(); + for (root_steps, ty_tokens) in jobs { + if let Err(e) = scan_type(&ty_tokens, root_steps, &mut refinements) { + let err = e.to_compile_error(); + return quote! { #err #func }.into(); + } + } + + if refinements.is_empty() { + return func.into_token_stream().into(); + } + + let has_receiver = func.sig().receiver().is_some(); + let mut formula_fns = Vec::new(); + let mut path_stmts = Vec::new(); + for mut r in refinements { + if has_receiver { + r.formula = rewrite_self_in_tokens(r.formula); + } + formula_fns.push(refine_formula_fn(&func, outer_context.as_ref(), &r)); + path_stmts.push(refine_path_stmt(&func, &r)); + } + + let Some(block) = func.block_mut() else { + let err = syn::Error::new_spanned( + func.sig().ident.clone(), + "refinement-type annotations require a function body", + ) + .into_compile_error(); + return quote! { #err #func }.into(); + }; + let orig_stmts = block.stmts.drain(..).collect::>(); + *block = syn::parse_quote!({ + #(#path_stmts)* + #(#orig_stmts)* + }); + func.attrs_mut() + .push(syn::parse_quote!(#[allow(path_statements)])); + + quote! { + #(#formula_fns)* + #func + } + .into() +} + +/// Builds `(root_steps, type_tokens)` jobs to scan from the attribute tokens. +/// +/// Each job's `root_steps` contains the initial [`RefineStep`]s that fix the +/// position of the type expression within the function signature (e.g. +/// `[Param(0)]` for the first parameter). [`scan_type`] will append further +/// [`RefineStep::TypeArg`] steps as it descends into generic type arguments. +fn build_refine_jobs( + kind: RefineKind, + func: &FnItemWithSignature, + attr_tokens: &[TokenTree2], +) -> syn::Result, Vec)>> { + match kind { + RefineKind::Param => { + let (name, ty_tokens) = split_name_type(attr_tokens)?; + let idx = param_index(func, &name)?; + Ok(vec![(vec![RefineStep::Param(idx)], ty_tokens)]) + } + RefineKind::Ret => Ok(vec![(vec![RefineStep::Return], attr_tokens.to_vec())]), + RefineKind::Sig => { + let (args, ret_tokens) = parse_sig_attr(attr_tokens)?; + let mut jobs = Vec::new(); + for (name, ty_tokens) in args { + let idx = param_index(func, &name)?; + jobs.push((vec![RefineStep::Param(idx)], ty_tokens)); + } + jobs.push((vec![RefineStep::Return], ret_tokens)); + Ok(jobs) + } + } +} + +fn param_index(func: &FnItemWithSignature, name: &syn::Ident) -> syn::Result { + let pos = func.sig().inputs.iter().position(|arg| match arg { + FnArg::Receiver(_) => name == "self", + FnArg::Typed(pt) => matches!(&*pt.pat, syn::Pat::Ident(pi) if &pi.ident == name), + }); + pos.ok_or_else(|| { + syn::Error::new_spanned(name, format!("no parameter named `{}` in signature", name)) + }) +} + +/// Parses `name : ` from a flat token slice. +fn split_name_type(tokens: &[TokenTree2]) -> syn::Result<(syn::Ident, Vec)> { + let name = match tokens.first() { + Some(TokenTree2::Ident(id)) => id.clone(), + _ => return Err(err_tokens(tokens, "expected a parameter name")), + }; + match tokens.get(1) { + Some(TokenTree2::Punct(p)) if p.as_char() == ':' => {} + _ => return Err(err_tokens(tokens, "expected `:` after parameter name")), + } + Ok((name, tokens[2..].to_vec())) +} + +/// Parses `fn ( n0: t0 , ... ) -> ret` into `((name, ty_tokens)*, ret_tokens)`. +#[allow(clippy::type_complexity)] +fn parse_sig_attr( + tokens: &[TokenTree2], +) -> syn::Result<(Vec<(syn::Ident, Vec)>, Vec)> { + match tokens.first() { + Some(TokenTree2::Ident(id)) if id == "fn" => {} + _ => return Err(err_tokens(tokens, "expected `fn` in sig annotation")), + } + let arg_group = match tokens.get(1) { + Some(TokenTree2::Group(g)) if g.delimiter() == proc_macro2::Delimiter::Parenthesis => g, + _ => return Err(err_tokens(tokens, "expected `(..)` after `fn`")), + }; + + let mut args = Vec::new(); + let arg_tokens: Vec = arg_group.stream().into_iter().collect(); + for arg in split_top_level_commas(&arg_tokens) { + if arg.is_empty() { + continue; + } + args.push(split_name_type(&arg)?); + } + + // expect `->` then the return type + let mut rest = &tokens[2..]; + match (rest.first(), rest.get(1)) { + (Some(TokenTree2::Punct(a)), Some(TokenTree2::Punct(b))) + if a.as_char() == '-' && b.as_char() == '>' => + { + rest = &rest[2..]; + } + _ => { + return Err(err_tokens( + tokens, + "expected `->` and a return type in sig annotation", + )) + } + } + Ok((args, rest.to_vec())) +} + +/// Scans a type expression and records every refinement node with its full +/// type-position path (`steps`). +/// +/// `steps` holds the path from the function root to the current type node. +/// When a refinement `{binder: ty | formula}` is found the current `steps` are +/// recorded; when descending into generic type arguments a +/// [`RefineStep::TypeArg`]`(i)` step is appended to `steps`. +fn scan_type( + tokens: &[TokenTree2], + steps: Vec, + out: &mut Vec, +) -> syn::Result<()> { + if tokens.is_empty() { + return Ok(()); + } + + // A refinement node is exactly a brace-delimited group. + if tokens.len() == 1 { + if let TokenTree2::Group(g) = &tokens[0] { + if g.delimiter() == proc_macro2::Delimiter::Brace { + let (binder, binder_ty, formula) = split_refinement(g.stream())?; + out.push(Refinement { + steps: steps.clone(), + binder, + binder_ty: binder_ty.iter().cloned().collect(), + formula, + }); + // Descend into the binder type for nested refinements. + scan_type(&binder_ty, steps, out)?; + return Ok(()); + } + } + } + + // A nominal type `Name` (`Box` included). + if let TokenTree2::Ident(_) = &tokens[0] { + if let Some(TokenTree2::Punct(p)) = tokens.get(1) { + if p.as_char() == '<' { + let mut type_idx = 0; + for arg in split_angle_args(&tokens[2..]) { + if is_lifetime(&arg) { + continue; + } + let mut child = steps.clone(); + child.push(RefineStep::TypeArg(type_idx)); + scan_type(&arg, child, out)?; + type_idx += 1; + } + } + } + } + + Ok(()) +} + +/// Splits `{ binder : ty | formula }` contents into its parts. +fn split_refinement( + stream: TokenStream2, +) -> syn::Result<(syn::Ident, Vec, TokenStream2)> { + let toks: Vec = stream.into_iter().collect(); + let bar = toks + .iter() + .position(|tt| matches!(tt, TokenTree2::Punct(p) if p.as_char() == '|')) + .ok_or_else(|| err_tokens(&toks, "refinement type must contain `|`"))?; + let (binder, binder_ty) = split_name_type(&toks[..bar])?; + let formula: TokenStream2 = toks[bar + 1..].iter().cloned().collect(); + Ok((binder, binder_ty, formula)) +} + +/// Splits the tokens following an opening `<` at top level by commas, stopping +/// at the matching `>`. +fn split_angle_args(tokens: &[TokenTree2]) -> Vec> { + let mut args = Vec::new(); + let mut cur = Vec::new(); + let mut depth = 1usize; + for tt in tokens { + if let TokenTree2::Punct(p) = tt { + match p.as_char() { + '<' => { + depth += 1; + cur.push(tt.clone()); + continue; + } + '>' => { + depth -= 1; + if depth == 0 { + break; + } + cur.push(tt.clone()); + continue; + } + ',' if depth == 1 => { + args.push(std::mem::take(&mut cur)); + continue; + } + _ => {} + } + } + cur.push(tt.clone()); + } + if !cur.is_empty() { + args.push(cur); + } + args +} + +fn split_top_level_commas(tokens: &[TokenTree2]) -> Vec> { + let mut out = Vec::new(); + let mut cur = Vec::new(); + let mut depth = 0i32; + for tt in tokens { + if let TokenTree2::Punct(p) = tt { + match p.as_char() { + '<' => depth += 1, + '>' => depth -= 1, + ',' if depth == 0 => { + out.push(std::mem::take(&mut cur)); + continue; + } + _ => {} + } + } + cur.push(tt.clone()); + } + out.push(cur); + out +} + +fn is_lifetime(tokens: &[TokenTree2]) -> bool { + matches!(tokens.first(), Some(TokenTree2::Punct(p)) if p.as_char() == '\'') +} + +fn err_tokens(tokens: &[TokenTree2], msg: &str) -> syn::Error { + let stream: TokenStream2 = tokens.iter().cloned().collect(); + syn::Error::new_spanned(stream, msg) +} + +fn rewrite_self_in_tokens(tokens: TokenStream2) -> TokenStream2 { + tokens + .into_iter() + .map(|tt| match tt { + TokenTree2::Ident(id) if id == "self" => TokenTree2::Ident(format_ident!("self_")), + TokenTree2::Group(g) => { + let inner = rewrite_self_in_tokens(g.stream()); + TokenTree2::Group(proc_macro2::Group::new(g.delimiter(), inner)) + } + other => other, + }) + .collect() +} + +fn refine_fn_name(func: &FnItemWithSignature, r: &Refinement) -> syn::Ident { + let pos = r + .steps + .iter() + .map(|s| match s { + RefineStep::Param(i) => format!("p{}", i), + RefineStep::Return => "ret".to_string(), + RefineStep::TypeArg(i) => format!("t{}", i), + }) + .collect::>() + .join("_"); + format_ident!("_thrust_refine_{}_{}", func.sig().ident, pos) +} + +fn refine_formula_fn( + func: &FnItemWithSignature, + outer_context: Option<&FnOuterItem>, + r: &Refinement, +) -> TokenStream2 { + let name = refine_fn_name(func, r); + let def_generics = generic_params_tokens(&func.sig().generics); + let model_params = fn_params_with_model_ty(&func.sig().inputs); + let model_preds = model_where_predicates(func, outer_context); + let extended_where = extended_where_clause(func, &model_preds); + let binder = &r.binder; + let binder_ty = &r.binder_ty; + let formula = &r.formula; + + quote! { + #[allow(unused_variables)] + #[allow(non_snake_case)] + #[thrust::formula_fn] + fn #name #def_generics( + #binder: <#binder_ty as thrust_models::Model>::Ty, + #model_params + ) -> bool #extended_where { + #formula + } + } +} + +fn refine_path_stmt(func: &FnItemWithSignature, r: &Refinement) -> TokenStream2 { + let name = refine_fn_name(func, r); + let turbofish = generic_turbofish(&func.sig().generics); + let path_prefix = if func.sig().receiver().is_some() { + quote!(Self::) + } else { + quote!() + }; + let encoded_steps = r.steps.iter().map(|s| match s { + RefineStep::Param(i) => { + let lit = proc_macro2::Literal::usize_unsuffixed(*i); + quote!(#lit) + } + RefineStep::Return => quote!(result), + RefineStep::TypeArg(i) => { + let lit = proc_macro2::Literal::usize_unsuffixed(*i); + quote!([#lit]) + } + }); + quote! { + #[thrust::refinement_path(#(#encoded_steps),*)] + #path_prefix #name #turbofish; + } +} From e11a748786ab2bae3cc25fb6182e3160e290d7d2 Mon Sep 17 00:00:00 2001 From: Claude Date: Wed, 27 May 2026 15:30:34 +0000 Subject: [PATCH 06/18] Use $i / bare integer syntax for refinement_path type positions Encode function parameters as $i (matching FunctionParamIdx's Display) and type arguments as bare integers, instead of bare integers for params and bracketed [i] for type args. Reads more naturally and keeps the attribute syntax consistent with how parameters are displayed elsewhere. https://claude.ai/code/session_01Km9xwaVaGQjnAy1iHsPewa --- src/analyze/annot.rs | 60 +++++++++++++++++-------------------- src/rty.rs | 14 ++++----- thrust-macros/src/refine.rs | 16 +++++----- 3 files changed, 42 insertions(+), 48 deletions(-) diff --git a/src/analyze/annot.rs b/src/analyze/annot.rs index 5ea243c0..00c5238d 100644 --- a/src/analyze/annot.rs +++ b/src/analyze/annot.rs @@ -225,16 +225,16 @@ pub fn extract_annot_tokens(attr: Attribute) -> TokenStream { /// Tokens are comma-separated steps. Each step is one of: /// - The keyword `result` → [`rty::TypePositionStep::Return`] (navigate to a /// function type's return slot). -/// - An integer literal `i` → [`rty::TypePositionStep::Param`]`(i)` (navigate -/// to the `i`-th parameter of a function type). -/// - A bracket group `[i]` → [`rty::TypePositionStep::TypeArg`]`(i)` (navigate -/// to the `i`-th type argument of a generic type such as an enum or `Box`). +/// - `$i` (a `$` followed by an integer) → [`rty::TypePositionStep::Param`]`(i)` +/// (navigate to the `i`-th parameter of a function type). +/// - A bare integer `i` → [`rty::TypePositionStep::TypeArg`]`(i)` (navigate to +/// the `i`-th type argument of a generic type such as an enum or `Box`). /// -/// Examples: `result` is the return; `0` is the first parameter; `0, [0]` is -/// the first type-argument of the first parameter; `0, result` is the return of -/// a function-typed first parameter. +/// Examples: `result` is the return; `$0` is the first parameter; `$0, 0` is +/// the first type-argument of the first parameter; `$0, result` is the return +/// of a function-typed first parameter. pub fn parse_type_position(ts: &TokenStream) -> rty::TypePosition { - use rustc_ast::token::{Delimiter, LitKind, TokenKind}; + use rustc_ast::token::{LitKind, TokenKind}; use rustc_ast::tokenstream::TokenTree; let parse_int = |lit: &rustc_ast::token::Lit| -> usize { @@ -250,36 +250,30 @@ pub fn parse_type_position(ts: &TokenStream) -> rty::TypePosition { }; let mut steps = Vec::new(); - for tt in ts.iter() { - match tt { - TokenTree::Token(t, _) => match &t.kind { - TokenKind::Comma => {} - TokenKind::Ident(sym, _) if sym.as_str() == "result" => { - steps.push(rty::TypePositionStep::Return); - } - TokenKind::Literal(lit) => { - steps.push(rty::TypePositionStep::Param(rty::FunctionParamIdx::from( - parse_int(lit), - ))); - } - _ => panic!("unexpected token in refine position: {:?}", t), - }, - TokenTree::Delimited(_, _, Delimiter::Bracket, inner) => { - let mut inner_iter = inner.iter(); - let i = match inner_iter.next() { + let mut iter = ts.iter(); + while let Some(tt) = iter.next() { + let TokenTree::Token(t, _) = tt else { + panic!("unexpected token tree in refine position"); + }; + match &t.kind { + TokenKind::Comma => {} + TokenKind::Ident(sym, _) if sym.as_str() == "result" => { + steps.push(rty::TypePositionStep::Return); + } + TokenKind::Dollar => { + let i = match iter.next() { Some(TokenTree::Token(t, _)) => match &t.kind { TokenKind::Literal(lit) => parse_int(lit), - _ => panic!("expected integer inside [..] refine step: {:?}", t), + _ => panic!("expected integer after `$` in refine position: {:?}", t), }, - _ => panic!("expected integer inside [..] refine step"), + _ => panic!("expected integer after `$` in refine position"), }; - assert!( - inner_iter.next().is_none(), - "expected exactly one integer inside [..] refine step" - ); - steps.push(rty::TypePositionStep::TypeArg(i)); + steps.push(rty::TypePositionStep::Param(rty::FunctionParamIdx::from(i))); + } + TokenKind::Literal(lit) => { + steps.push(rty::TypePositionStep::TypeArg(parse_int(lit))); } - _ => panic!("unexpected token tree in refine position"), + _ => panic!("unexpected token in refine position: {:?}", t), } } diff --git a/src/rty.rs b/src/rty.rs index b0f10f3d..9466eaac 100644 --- a/src/rty.rs +++ b/src/rty.rs @@ -95,7 +95,7 @@ where /// Using distinct variants for function navigation ([`Param`](Self::Param), /// [`Return`](Self::Return)) and generic-arg navigation /// ([`TypeArg`](Self::TypeArg)) allows the same path representation to address -/// positions inside higher-order function types. For example, `[$0, result]` +/// positions inside higher-order function types. For example, `$0.result` /// addresses the return type of a function-typed first parameter. #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum TypePositionStep { @@ -112,7 +112,7 @@ impl std::fmt::Display for TypePositionStep { match self { TypePositionStep::Param(idx) => write!(f, "{}", idx), TypePositionStep::Return => f.write_str("result"), - TypePositionStep::TypeArg(i) => write!(f, "[{}]", i), + TypePositionStep::TypeArg(i) => write!(f, "{}", i), } } } @@ -128,11 +128,11 @@ impl std::fmt::Display for TypePositionStep { /// types), enabling positions inside higher-order function types. /// /// Examples (function `fn f(x: List) -> Box`): -/// - `[$0]` — parameter `x`. -/// - `[result]` — the return type. -/// - `[$0, [0]]` — the first type arg of `x`. -/// - `[result, [0]]` — the pointee of the `Box` return. -/// - `[$0, result]` — the return of a function-typed param `x`. +/// - `$0` — parameter `x`. +/// - `result` — the return type. +/// - `$0.0` — the first type arg of `x`. +/// - `result.0` — the pointee of the `Box` return. +/// - `$0.result` — the return of a function-typed param `x`. #[derive(Debug, Clone, PartialEq, Eq)] pub struct TypePosition { steps: Vec, diff --git a/thrust-macros/src/refine.rs b/thrust-macros/src/refine.rs index 9199e660..06353817 100644 --- a/thrust-macros/src/refine.rs +++ b/thrust-macros/src/refine.rs @@ -3,10 +3,10 @@ //! These lower refinement types (e.g. `List<{ v: i32 | v > 0 }>`) into //! `#[thrust::formula_fn]`s plus positioned `#[thrust::refinement_path(..)]` //! path statements injected into the function body. The "type position" -//! addresses into the function type: a parameter (by index) or the return (the -//! `result` keyword) selects a function slot, and bracket steps (`[i]`) descend -//! into generic arguments (enum args / `Box` pointee). For example, -//! `#[thrust::refinement_path(result, [0])]` is the first type-argument of the +//! addresses into the function type: a parameter (`$i`) or the return (the +//! `result` keyword) selects a function slot, and bare integer steps (`i`) +//! descend into generic arguments (enum args / `Box` pointee). For example, +//! `#[thrust::refinement_path(result, 0)]` is the first type-argument of the //! return. use proc_macro::TokenStream; @@ -24,9 +24,9 @@ use super::{ /// Mirrors [`rty::TypePositionStep`] on the plugin side and uses the same /// attribute encoding: /// - [`Param`](Self::Param) / [`Return`](Self::Return) navigate into a function -/// type; encoded as an integer literal / the `result` keyword. +/// type; encoded as `$i` / the `result` keyword. /// - [`TypeArg`](Self::TypeArg) navigates into a generic type argument; encoded -/// as a bracket group `[i]`. +/// as a bare integer `i`. #[derive(Clone, Copy)] enum RefineStep { Param(usize), @@ -413,12 +413,12 @@ fn refine_path_stmt(func: &FnItemWithSignature, r: &Refinement) -> TokenStream2 let encoded_steps = r.steps.iter().map(|s| match s { RefineStep::Param(i) => { let lit = proc_macro2::Literal::usize_unsuffixed(*i); - quote!(#lit) + quote!($#lit) } RefineStep::Return => quote!(result), RefineStep::TypeArg(i) => { let lit = proc_macro2::Literal::usize_unsuffixed(*i); - quote!([#lit]) + quote!(#lit) } }); quote! { From 8123bd83c7d17423427c1e259dfbbf38f1f8c6b1 Mon Sep 17 00:00:00 2001 From: Claude Date: Wed, 27 May 2026 15:47:38 +0000 Subject: [PATCH 07/18] Clarify refinement type-position naming and align Display with syntax - Rename FunctionTemplateTypeBuilder::refine to install_refinement_at, matching the RefinedType method it delegates to. - Rename extract_refine_paths / extract_refine_annots to extract_refinement_paths / extract_refinement_annots, and clarify the related panic messages, to reduce the overloaded use of "refine". - Make TypePosition's Display match the refinement_path(..) surface syntax (comma-separated steps) instead of dot-separated. - Replace the macro's build_refine_jobs (and its complex tuple return type) with annotated_type_exprs returning PositionedTypeExpr, and introduce NamedType / SigAnnotation structs so parse_sig_attr no longer needs an allow(clippy::type_complexity). Rename RefineStep to PositionStep, RefineKind to AnnotationKind, and expand_refine to expand for clarity. https://claude.ai/code/session_01Km9xwaVaGQjnAy1iHsPewa --- src/analyze.rs | 15 ++-- src/analyze/local_def.rs | 8 +- src/refine/template.rs | 2 +- src/rty.rs | 10 +-- thrust-macros/src/lib.rs | 6 +- thrust-macros/src/refine.rs | 149 +++++++++++++++++++++--------------- 6 files changed, 111 insertions(+), 79 deletions(-) diff --git a/src/analyze.rs b/src/analyze.rs index e16fb08b..e7420df9 100644 --- a/src/analyze.rs +++ b/src/analyze.rs @@ -792,7 +792,10 @@ impl<'tcx> Analyzer<'tcx> { /// Collects every `#[thrust::refinement_path(..)]` path statement in the /// function body, returning each `(type position, formula_fn DefId)`. - fn extract_refine_paths(&self, local_def_id: LocalDefId) -> Vec<(rty::TypePosition, DefId)> { + fn extract_refinement_paths( + &self, + local_def_id: LocalDefId, + ) -> Vec<(rty::TypePosition, DefId)> { let mut out = Vec::new(); let Some(body) = self.tcx.hir_maybe_body_owned_by(local_def_id) else { return out; @@ -842,27 +845,27 @@ impl<'tcx> Analyzer<'tcx> { /// Resolves every `#[thrust::refinement_path(..)]` annotation into a /// positioned refinement, by translating the referenced formula function. - pub fn extract_refine_annots( + pub fn extract_refinement_annots( &self, local_def_id: LocalDefId, generic_args: mir_ty::GenericArgsRef<'tcx>, ) -> Vec<(rty::TypePosition, rty::Refinement)> { let mut out = Vec::new(); - for (position, def_id) in self.extract_refine_paths(local_def_id) { + for (position, def_id) in self.extract_refinement_paths(local_def_id) { let Some(formula_def_id) = def_id.as_local() else { panic!( - "refine annotation with path is expected to refer to a local def, but found: {:?}", + "refinement_path annotation is expected to refer to a local def, but found: {:?}", def_id ); }; let Some(formula_fn) = self.formula_fn_with_args(formula_def_id, generic_args) else { panic!( - "refine annotation {:?} is not a formula function", + "refinement_path annotation {:?} is not a formula function", formula_def_id ); }; let AnnotFormula::Formula(formula) = formula_fn.to_ensure_annot() else { - panic!("refine annotation must lower to a plain formula"); + panic!("refinement_path annotation must lower to a plain formula"); }; out.push((position, formula.into())); } diff --git a/src/analyze/local_def.rs b/src/analyze/local_def.rs index 5fd6f6ce..b3472cb6 100644 --- a/src/analyze/local_def.rs +++ b/src/analyze/local_def.rs @@ -406,9 +406,9 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { assert!(require_annot.is_none() || param_annots.is_empty()); assert!(ensure_annot.is_none() || ret_annot.is_none()); - let refine_annots = self + let refinement_annots = self .ctx - .extract_refine_annots(self.local_def_id, self.generic_args); + .extract_refinement_annots(self.local_def_id, self.generic_args); let trait_item_ty = self.trait_item_ty(); let is_fully_annotated = self.is_fully_annotated(); @@ -435,8 +435,8 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { if let Some(ret_rty) = ret_annot { builder.ret_rty(ret_rty); } - for (position, refinement) in refine_annots { - builder.refine(&position, refinement); + for (position, refinement) in refinement_annots { + builder.install_refinement_at(&position, refinement); } if is_fully_annotated { diff --git a/src/refine/template.rs b/src/refine/template.rs index 50408612..e88008a6 100644 --- a/src/refine/template.rs +++ b/src/refine/template.rs @@ -571,7 +571,7 @@ impl<'tcx, 'a, R> FunctionTemplateTypeBuilder<'tcx, 'a, R> { /// The first step must be [`rty::TypePositionStep::Param`] or /// [`rty::TypePositionStep::Return`]; the remaining steps are forwarded to /// [`rty::RefinedType::install_refinement_at`]. - pub fn refine( + pub fn install_refinement_at( &mut self, position: &rty::TypePosition, refinement: rty::Refinement, diff --git a/src/rty.rs b/src/rty.rs index 9466eaac..c0c4efd7 100644 --- a/src/rty.rs +++ b/src/rty.rs @@ -95,7 +95,7 @@ where /// Using distinct variants for function navigation ([`Param`](Self::Param), /// [`Return`](Self::Return)) and generic-arg navigation /// ([`TypeArg`](Self::TypeArg)) allows the same path representation to address -/// positions inside higher-order function types. For example, `$0.result` +/// positions inside higher-order function types. For example, `$0, result` /// addresses the return type of a function-typed first parameter. #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum TypePositionStep { @@ -130,9 +130,9 @@ impl std::fmt::Display for TypePositionStep { /// Examples (function `fn f(x: List) -> Box`): /// - `$0` — parameter `x`. /// - `result` — the return type. -/// - `$0.0` — the first type arg of `x`. -/// - `result.0` — the pointee of the `Box` return. -/// - `$0.result` — the return of a function-typed param `x`. +/// - `$0, 0` — the first type arg of `x`. +/// - `result, 0` — the pointee of the `Box` return. +/// - `$0, result` — the return of a function-typed param `x`. #[derive(Debug, Clone, PartialEq, Eq)] pub struct TypePosition { steps: Vec, @@ -157,7 +157,7 @@ impl std::fmt::Display for TypePosition { write!(f, "{}", first)?; } for s in iter { - write!(f, ".{}", s)?; + write!(f, ", {}", s)?; } Ok(()) } diff --git a/thrust-macros/src/lib.rs b/thrust-macros/src/lib.rs index debb22aa..a90ede1a 100644 --- a/thrust-macros/src/lib.rs +++ b/thrust-macros/src/lib.rs @@ -574,17 +574,17 @@ impl ExpandedTokens { #[proc_macro_attribute] pub fn param(attr: TokenStream, item: TokenStream) -> TokenStream { - refine::expand_refine(refine::RefineKind::Param, attr, item) + refine::expand(refine::AnnotationKind::Param, attr, item) } #[proc_macro_attribute] pub fn ret(attr: TokenStream, item: TokenStream) -> TokenStream { - refine::expand_refine(refine::RefineKind::Ret, attr, item) + refine::expand(refine::AnnotationKind::Ret, attr, item) } #[proc_macro_attribute] pub fn sig(attr: TokenStream, item: TokenStream) -> TokenStream { - refine::expand_refine(refine::RefineKind::Sig, attr, item) + refine::expand(refine::AnnotationKind::Sig, attr, item) } fn mentions_self(sig: &syn::Signature) -> bool { diff --git a/thrust-macros/src/refine.rs b/thrust-macros/src/refine.rs index 06353817..102f53e4 100644 --- a/thrust-macros/src/refine.rs +++ b/thrust-macros/src/refine.rs @@ -28,7 +28,7 @@ use super::{ /// - [`TypeArg`](Self::TypeArg) navigates into a generic type argument; encoded /// as a bare integer `i`. #[derive(Clone, Copy)] -enum RefineStep { +enum PositionStep { Param(usize), Return, TypeArg(usize), @@ -37,19 +37,37 @@ enum RefineStep { #[derive(Clone)] struct Refinement { /// Full type-position path from the function root to the refined type. - steps: Vec, + steps: Vec, binder: syn::Ident, binder_ty: TokenStream2, formula: TokenStream2, } -pub(crate) enum RefineKind { +/// Which refinement-type annotation is being expanded. +pub(crate) enum AnnotationKind { Param, Ret, Sig, } -pub(crate) fn expand_refine(kind: RefineKind, attr: TokenStream, item: TokenStream) -> TokenStream { +/// A type expression from the annotation, paired with the position of its root +/// within the function signature. [`scan_type`] walks each one to extract the +/// refinements it contains. +struct PositionedTypeExpr { + /// Steps locating the root of `tokens` (e.g. `[Param(0)]` for the first + /// parameter); [`scan_type`] appends `TypeArg` steps as it descends. + root: Vec, + tokens: Vec, +} + +/// A `name : type` binding, e.g. a parameter in a `sig` annotation or the +/// binder of a refinement `{ name: type | .. }`. +struct NamedType { + name: syn::Ident, + tokens: Vec, +} + +pub(crate) fn expand(kind: AnnotationKind, attr: TokenStream, item: TokenStream) -> TokenStream { let mut func = parse_macro_input!(item as FnItemWithSignature); let outer_context = match extract_outer_context(&func) { @@ -61,8 +79,8 @@ pub(crate) fn expand_refine(kind: RefineKind, attr: TokenStream, item: TokenStre }; let attr_tokens: Vec = TokenStream2::from(attr).into_iter().collect(); - let jobs = match build_refine_jobs(kind, &func, &attr_tokens) { - Ok(jobs) => jobs, + let type_exprs = match annotated_type_exprs(kind, &func, &attr_tokens) { + Ok(exprs) => exprs, Err(e) => { let err = e.to_compile_error(); return quote! { #err #func }.into(); @@ -70,8 +88,8 @@ pub(crate) fn expand_refine(kind: RefineKind, attr: TokenStream, item: TokenStre }; let mut refinements = Vec::new(); - for (root_steps, ty_tokens) in jobs { - if let Err(e) = scan_type(&ty_tokens, root_steps, &mut refinements) { + for expr in type_exprs { + if let Err(e) = scan_type(&expr.tokens, expr.root, &mut refinements) { let err = e.to_compile_error(); return quote! { #err #func }.into(); } @@ -115,33 +133,37 @@ pub(crate) fn expand_refine(kind: RefineKind, attr: TokenStream, item: TokenStre .into() } -/// Builds `(root_steps, type_tokens)` jobs to scan from the attribute tokens. -/// -/// Each job's `root_steps` contains the initial [`RefineStep`]s that fix the -/// position of the type expression within the function signature (e.g. -/// `[Param(0)]` for the first parameter). [`scan_type`] will append further -/// [`RefineStep::TypeArg`] steps as it descends into generic type arguments. -fn build_refine_jobs( - kind: RefineKind, +/// Turns an annotation into the type expressions to scan, each anchored at its +/// root position within the function signature. +fn annotated_type_exprs( + kind: AnnotationKind, func: &FnItemWithSignature, attr_tokens: &[TokenTree2], -) -> syn::Result, Vec)>> { +) -> syn::Result> { + let at_param = |func: &FnItemWithSignature, nt: NamedType| -> syn::Result { + let idx = param_index(func, &nt.name)?; + Ok(PositionedTypeExpr { + root: vec![PositionStep::Param(idx)], + tokens: nt.tokens, + }) + }; match kind { - RefineKind::Param => { - let (name, ty_tokens) = split_name_type(attr_tokens)?; - let idx = param_index(func, &name)?; - Ok(vec![(vec![RefineStep::Param(idx)], ty_tokens)]) - } - RefineKind::Ret => Ok(vec![(vec![RefineStep::Return], attr_tokens.to_vec())]), - RefineKind::Sig => { - let (args, ret_tokens) = parse_sig_attr(attr_tokens)?; - let mut jobs = Vec::new(); - for (name, ty_tokens) in args { - let idx = param_index(func, &name)?; - jobs.push((vec![RefineStep::Param(idx)], ty_tokens)); + AnnotationKind::Param => Ok(vec![at_param(func, split_name_type(attr_tokens)?)?]), + AnnotationKind::Ret => Ok(vec![PositionedTypeExpr { + root: vec![PositionStep::Return], + tokens: attr_tokens.to_vec(), + }]), + AnnotationKind::Sig => { + let sig = parse_sig_attr(attr_tokens)?; + let mut exprs = Vec::new(); + for param in sig.params { + exprs.push(at_param(func, param)?); } - jobs.push((vec![RefineStep::Return], ret_tokens)); - Ok(jobs) + exprs.push(PositionedTypeExpr { + root: vec![PositionStep::Return], + tokens: sig.ret, + }); + Ok(exprs) } } } @@ -157,7 +179,7 @@ fn param_index(func: &FnItemWithSignature, name: &syn::Ident) -> syn::Result` from a flat token slice. -fn split_name_type(tokens: &[TokenTree2]) -> syn::Result<(syn::Ident, Vec)> { +fn split_name_type(tokens: &[TokenTree2]) -> syn::Result { let name = match tokens.first() { Some(TokenTree2::Ident(id)) => id.clone(), _ => return Err(err_tokens(tokens, "expected a parameter name")), @@ -166,14 +188,20 @@ fn split_name_type(tokens: &[TokenTree2]) -> syn::Result<(syn::Ident, Vec {} _ => return Err(err_tokens(tokens, "expected `:` after parameter name")), } - Ok((name, tokens[2..].to_vec())) + Ok(NamedType { + name, + tokens: tokens[2..].to_vec(), + }) } -/// Parses `fn ( n0: t0 , ... ) -> ret` into `((name, ty_tokens)*, ret_tokens)`. -#[allow(clippy::type_complexity)] -fn parse_sig_attr( - tokens: &[TokenTree2], -) -> syn::Result<(Vec<(syn::Ident, Vec)>, Vec)> { +/// The parsed parts of a `fn ( n0: t0 , ... ) -> ret` signature annotation. +struct SigAnnotation { + params: Vec, + ret: Vec, +} + +/// Parses `fn ( n0: t0 , ... ) -> ret`. +fn parse_sig_attr(tokens: &[TokenTree2]) -> syn::Result { match tokens.first() { Some(TokenTree2::Ident(id)) if id == "fn" => {} _ => return Err(err_tokens(tokens, "expected `fn` in sig annotation")), @@ -183,13 +211,13 @@ fn parse_sig_attr( _ => return Err(err_tokens(tokens, "expected `(..)` after `fn`")), }; - let mut args = Vec::new(); + let mut params = Vec::new(); let arg_tokens: Vec = arg_group.stream().into_iter().collect(); for arg in split_top_level_commas(&arg_tokens) { if arg.is_empty() { continue; } - args.push(split_name_type(&arg)?); + params.push(split_name_type(&arg)?); } // expect `->` then the return type @@ -207,7 +235,10 @@ fn parse_sig_attr( )) } } - Ok((args, rest.to_vec())) + Ok(SigAnnotation { + params, + ret: rest.to_vec(), + }) } /// Scans a type expression and records every refinement node with its full @@ -216,10 +247,10 @@ fn parse_sig_attr( /// `steps` holds the path from the function root to the current type node. /// When a refinement `{binder: ty | formula}` is found the current `steps` are /// recorded; when descending into generic type arguments a -/// [`RefineStep::TypeArg`]`(i)` step is appended to `steps`. +/// [`PositionStep::TypeArg`]`(i)` step is appended to `steps`. fn scan_type( tokens: &[TokenTree2], - steps: Vec, + steps: Vec, out: &mut Vec, ) -> syn::Result<()> { if tokens.is_empty() { @@ -230,15 +261,15 @@ fn scan_type( if tokens.len() == 1 { if let TokenTree2::Group(g) = &tokens[0] { if g.delimiter() == proc_macro2::Delimiter::Brace { - let (binder, binder_ty, formula) = split_refinement(g.stream())?; + let (binder, formula) = split_refinement(g.stream())?; out.push(Refinement { steps: steps.clone(), - binder, - binder_ty: binder_ty.iter().cloned().collect(), + binder: binder.name, + binder_ty: binder.tokens.iter().cloned().collect(), formula, }); // Descend into the binder type for nested refinements. - scan_type(&binder_ty, steps, out)?; + scan_type(&binder.tokens, steps, out)?; return Ok(()); } } @@ -254,7 +285,7 @@ fn scan_type( continue; } let mut child = steps.clone(); - child.push(RefineStep::TypeArg(type_idx)); + child.push(PositionStep::TypeArg(type_idx)); scan_type(&arg, child, out)?; type_idx += 1; } @@ -265,18 +296,16 @@ fn scan_type( Ok(()) } -/// Splits `{ binder : ty | formula }` contents into its parts. -fn split_refinement( - stream: TokenStream2, -) -> syn::Result<(syn::Ident, Vec, TokenStream2)> { +/// Splits `{ binder : ty | formula }` into its binder and formula expression. +fn split_refinement(stream: TokenStream2) -> syn::Result<(NamedType, TokenStream2)> { let toks: Vec = stream.into_iter().collect(); let bar = toks .iter() .position(|tt| matches!(tt, TokenTree2::Punct(p) if p.as_char() == '|')) .ok_or_else(|| err_tokens(&toks, "refinement type must contain `|`"))?; - let (binder, binder_ty) = split_name_type(&toks[..bar])?; + let binder = split_name_type(&toks[..bar])?; let formula: TokenStream2 = toks[bar + 1..].iter().cloned().collect(); - Ok((binder, binder_ty, formula)) + Ok((binder, formula)) } /// Splits the tokens following an opening `<` at top level by commas, stopping @@ -366,9 +395,9 @@ fn refine_fn_name(func: &FnItemWithSignature, r: &Refinement) -> syn::Ident { .steps .iter() .map(|s| match s { - RefineStep::Param(i) => format!("p{}", i), - RefineStep::Return => "ret".to_string(), - RefineStep::TypeArg(i) => format!("t{}", i), + PositionStep::Param(i) => format!("p{}", i), + PositionStep::Return => "ret".to_string(), + PositionStep::TypeArg(i) => format!("t{}", i), }) .collect::>() .join("_"); @@ -411,12 +440,12 @@ fn refine_path_stmt(func: &FnItemWithSignature, r: &Refinement) -> TokenStream2 quote!() }; let encoded_steps = r.steps.iter().map(|s| match s { - RefineStep::Param(i) => { + PositionStep::Param(i) => { let lit = proc_macro2::Literal::usize_unsuffixed(*i); quote!($#lit) } - RefineStep::Return => quote!(result), - RefineStep::TypeArg(i) => { + PositionStep::Return => quote!(result), + PositionStep::TypeArg(i) => { let lit = proc_macro2::Literal::usize_unsuffixed(*i); quote!(#lit) } From 70a8070f64a6e1970b30ad2e20b32283125590af Mon Sep 17 00:00:00 2001 From: Claude Date: Wed, 27 May 2026 16:23:55 +0000 Subject: [PATCH 08/18] Refine type-position terminology and drop non-empty invariant - Replace the undefined phrase "refine position" in parser diagnostics with "type position". - Drop the non-empty invariant from TypePosition: an empty path is a valid notion (it addresses the type itself); a path is only non-empty when applied to a function type. TypePosition::new now takes the full step vector, and that non-emptiness is checked where it matters (the function-type builder). - Rename the macro module file thrust-macros/src/refine.rs to rty.rs and use TypePositionStep there, mirroring the plugin's rty module instead of naming the same concept differently. https://claude.ai/code/session_01Km9xwaVaGQjnAy1iHsPewa --- src/analyze/annot.rs | 16 +++++------- src/refine/template.rs | 4 +-- src/rty.rs | 19 ++++++-------- thrust-macros/src/lib.rs | 8 +++--- thrust-macros/src/{refine.rs => rty.rs} | 34 ++++++++++++------------- 5 files changed, 38 insertions(+), 43 deletions(-) rename thrust-macros/src/{refine.rs => rty.rs} (94%) diff --git a/src/analyze/annot.rs b/src/analyze/annot.rs index 00c5238d..c278e9cb 100644 --- a/src/analyze/annot.rs +++ b/src/analyze/annot.rs @@ -241,19 +241,19 @@ pub fn parse_type_position(ts: &TokenStream) -> rty::TypePosition { assert_eq!( lit.kind, LitKind::Integer, - "expected an integer in refine position" + "expected an integer in type position" ); lit.symbol .as_str() .parse() - .expect("invalid integer in refine position") + .expect("invalid integer in type position") }; let mut steps = Vec::new(); let mut iter = ts.iter(); while let Some(tt) = iter.next() { let TokenTree::Token(t, _) = tt else { - panic!("unexpected token tree in refine position"); + panic!("unexpected token tree in type position"); }; match &t.kind { TokenKind::Comma => {} @@ -264,22 +264,20 @@ pub fn parse_type_position(ts: &TokenStream) -> rty::TypePosition { let i = match iter.next() { Some(TokenTree::Token(t, _)) => match &t.kind { TokenKind::Literal(lit) => parse_int(lit), - _ => panic!("expected integer after `$` in refine position: {:?}", t), + _ => panic!("expected integer after `$` in type position: {:?}", t), }, - _ => panic!("expected integer after `$` in refine position"), + _ => panic!("expected integer after `$` in type position"), }; steps.push(rty::TypePositionStep::Param(rty::FunctionParamIdx::from(i))); } TokenKind::Literal(lit) => { steps.push(rty::TypePositionStep::TypeArg(parse_int(lit))); } - _ => panic!("unexpected token in refine position: {:?}", t), + _ => panic!("unexpected token in type position: {:?}", t), } } - assert!(!steps.is_empty(), "empty refine position"); - let first = steps.remove(0); - rty::TypePosition::new(first, steps) + rty::TypePosition::new(steps) } pub fn split_param(ts: &TokenStream) -> (Ident, TokenStream) { diff --git a/src/refine/template.rs b/src/refine/template.rs index e88008a6..37956784 100644 --- a/src/refine/template.rs +++ b/src/refine/template.rs @@ -578,7 +578,7 @@ impl<'tcx, 'a, R> FunctionTemplateTypeBuilder<'tcx, 'a, R> { ) -> &mut Self { let (first, rest) = match position.steps().split_first() { Some(pair) => pair, - None => panic!("empty TypePosition"), + None => panic!("type position applied to a function type must not be empty"), }; match first { rty::TypePositionStep::Param(idx) => { @@ -603,7 +603,7 @@ impl<'tcx, 'a, R> FunctionTemplateTypeBuilder<'tcx, 'a, R> { .install_refinement_at(rest, refinement); } rty::TypePositionStep::TypeArg(_) => { - panic!("TypePosition must start with Param or Return, not TypeArg"); + panic!("type position applied to a function type must start with a param or result step, not a type argument"); } } self diff --git a/src/rty.rs b/src/rty.rs index c0c4efd7..abfff7bf 100644 --- a/src/rty.rs +++ b/src/rty.rs @@ -117,15 +117,14 @@ impl std::fmt::Display for TypePositionStep { } } -/// A path addressing a sub-type in a function's type signature, used to attach -/// a refinement. +/// A path addressing a sub-type within a type, used to attach a refinement. /// -/// The first step must be [`TypePositionStep::Param`] or -/// [`TypePositionStep::Return`] (selecting which slot of the top-level function -/// type to enter). Subsequent steps can freely combine -/// [`TypePositionStep::Param`] / [`TypePositionStep::Return`] (for -/// function-typed positions) and [`TypePositionStep::TypeArg`] (for generic -/// types), enabling positions inside higher-order function types. +/// An empty path addresses the type itself. Each step descends one level: +/// [`TypePositionStep::Param`] / [`TypePositionStep::Return`] enter a function +/// type's parameter or return slot, and [`TypePositionStep::TypeArg`] enters a +/// generic type argument. Steps combine freely, so positions inside +/// higher-order function types are expressible. A path applied to a function +/// type is therefore non-empty, beginning with a `Param`/`Return` step. /// /// Examples (function `fn f(x: List) -> Box`): /// - `$0` — parameter `x`. @@ -139,9 +138,7 @@ pub struct TypePosition { } impl TypePosition { - pub fn new(first: TypePositionStep, rest: Vec) -> Self { - let mut steps = vec![first]; - steps.extend(rest); + pub fn new(steps: Vec) -> Self { TypePosition { steps } } diff --git a/thrust-macros/src/lib.rs b/thrust-macros/src/lib.rs index a90ede1a..6fb0bbae 100644 --- a/thrust-macros/src/lib.rs +++ b/thrust-macros/src/lib.rs @@ -7,7 +7,7 @@ use syn::{ }; mod invariant; -mod refine; +mod rty; #[derive(Debug, Clone)] enum FnOuterItem { @@ -574,17 +574,17 @@ impl ExpandedTokens { #[proc_macro_attribute] pub fn param(attr: TokenStream, item: TokenStream) -> TokenStream { - refine::expand(refine::AnnotationKind::Param, attr, item) + rty::expand(rty::AnnotationKind::Param, attr, item) } #[proc_macro_attribute] pub fn ret(attr: TokenStream, item: TokenStream) -> TokenStream { - refine::expand(refine::AnnotationKind::Ret, attr, item) + rty::expand(rty::AnnotationKind::Ret, attr, item) } #[proc_macro_attribute] pub fn sig(attr: TokenStream, item: TokenStream) -> TokenStream { - refine::expand(refine::AnnotationKind::Sig, attr, item) + rty::expand(rty::AnnotationKind::Sig, attr, item) } fn mentions_self(sig: &syn::Signature) -> bool { diff --git a/thrust-macros/src/refine.rs b/thrust-macros/src/rty.rs similarity index 94% rename from thrust-macros/src/refine.rs rename to thrust-macros/src/rty.rs index 102f53e4..e32031e7 100644 --- a/thrust-macros/src/refine.rs +++ b/thrust-macros/src/rty.rs @@ -21,14 +21,14 @@ use super::{ /// One step in a refinement's type-position path. /// -/// Mirrors [`rty::TypePositionStep`] on the plugin side and uses the same -/// attribute encoding: +/// Mirrors the plugin's `rty::TypePositionStep` and uses the same attribute +/// encoding: /// - [`Param`](Self::Param) / [`Return`](Self::Return) navigate into a function /// type; encoded as `$i` / the `result` keyword. /// - [`TypeArg`](Self::TypeArg) navigates into a generic type argument; encoded /// as a bare integer `i`. #[derive(Clone, Copy)] -enum PositionStep { +enum TypePositionStep { Param(usize), Return, TypeArg(usize), @@ -37,7 +37,7 @@ enum PositionStep { #[derive(Clone)] struct Refinement { /// Full type-position path from the function root to the refined type. - steps: Vec, + steps: Vec, binder: syn::Ident, binder_ty: TokenStream2, formula: TokenStream2, @@ -56,7 +56,7 @@ pub(crate) enum AnnotationKind { struct PositionedTypeExpr { /// Steps locating the root of `tokens` (e.g. `[Param(0)]` for the first /// parameter); [`scan_type`] appends `TypeArg` steps as it descends. - root: Vec, + root: Vec, tokens: Vec, } @@ -143,14 +143,14 @@ fn annotated_type_exprs( let at_param = |func: &FnItemWithSignature, nt: NamedType| -> syn::Result { let idx = param_index(func, &nt.name)?; Ok(PositionedTypeExpr { - root: vec![PositionStep::Param(idx)], + root: vec![TypePositionStep::Param(idx)], tokens: nt.tokens, }) }; match kind { AnnotationKind::Param => Ok(vec![at_param(func, split_name_type(attr_tokens)?)?]), AnnotationKind::Ret => Ok(vec![PositionedTypeExpr { - root: vec![PositionStep::Return], + root: vec![TypePositionStep::Return], tokens: attr_tokens.to_vec(), }]), AnnotationKind::Sig => { @@ -160,7 +160,7 @@ fn annotated_type_exprs( exprs.push(at_param(func, param)?); } exprs.push(PositionedTypeExpr { - root: vec![PositionStep::Return], + root: vec![TypePositionStep::Return], tokens: sig.ret, }); Ok(exprs) @@ -247,10 +247,10 @@ fn parse_sig_attr(tokens: &[TokenTree2]) -> syn::Result { /// `steps` holds the path from the function root to the current type node. /// When a refinement `{binder: ty | formula}` is found the current `steps` are /// recorded; when descending into generic type arguments a -/// [`PositionStep::TypeArg`]`(i)` step is appended to `steps`. +/// [`TypePositionStep::TypeArg`]`(i)` step is appended to `steps`. fn scan_type( tokens: &[TokenTree2], - steps: Vec, + steps: Vec, out: &mut Vec, ) -> syn::Result<()> { if tokens.is_empty() { @@ -285,7 +285,7 @@ fn scan_type( continue; } let mut child = steps.clone(); - child.push(PositionStep::TypeArg(type_idx)); + child.push(TypePositionStep::TypeArg(type_idx)); scan_type(&arg, child, out)?; type_idx += 1; } @@ -395,9 +395,9 @@ fn refine_fn_name(func: &FnItemWithSignature, r: &Refinement) -> syn::Ident { .steps .iter() .map(|s| match s { - PositionStep::Param(i) => format!("p{}", i), - PositionStep::Return => "ret".to_string(), - PositionStep::TypeArg(i) => format!("t{}", i), + TypePositionStep::Param(i) => format!("p{}", i), + TypePositionStep::Return => "ret".to_string(), + TypePositionStep::TypeArg(i) => format!("t{}", i), }) .collect::>() .join("_"); @@ -440,12 +440,12 @@ fn refine_path_stmt(func: &FnItemWithSignature, r: &Refinement) -> TokenStream2 quote!() }; let encoded_steps = r.steps.iter().map(|s| match s { - PositionStep::Param(i) => { + TypePositionStep::Param(i) => { let lit = proc_macro2::Literal::usize_unsuffixed(*i); quote!($#lit) } - PositionStep::Return => quote!(result), - PositionStep::TypeArg(i) => { + TypePositionStep::Return => quote!(result), + TypePositionStep::TypeArg(i) => { let lit = proc_macro2::Literal::usize_unsuffixed(*i); quote!(#lit) } From 2322d32668208516ad6b1c78fae6c91fd6b25e91 Mon Sep 17 00:00:00 2001 From: Claude Date: Sat, 30 May 2026 03:30:55 +0000 Subject: [PATCH 09/18] Fix mis-named conversion at refinement_path extraction site extract_refinement_annots was using FormulaFn::to_ensure_annot to build the refinement, which is a misleading name for that call site (a refinement_path is not an ensures annotation). Introduce a properly-named FormulaFn::to_refinement that returns a rty::Refinement directly, removing the awkward AnnotFormula unwrap-or-panic on what was always the Formula variant. The shared var mapping (param 0 -> Value, rest -> Free) lives in a private helper used by both to_ensure_annot and to_refinement. Also rename FunctionTemplateTypeBuilder::install_refinement_at to refinement_at to match the builder's existing param_refinement / ret_refinement convention. https://claude.ai/code/session_01Km9xwaVaGQjnAy1iHsPewa --- src/analyze.rs | 5 +---- src/analyze/annot_fn.rs | 24 ++++++++++++++++++++++-- src/analyze/local_def.rs | 2 +- src/refine/template.rs | 4 ++-- 4 files changed, 26 insertions(+), 9 deletions(-) diff --git a/src/analyze.rs b/src/analyze.rs index e7420df9..6ad1719e 100644 --- a/src/analyze.rs +++ b/src/analyze.rs @@ -864,10 +864,7 @@ impl<'tcx> Analyzer<'tcx> { formula_def_id ); }; - let AnnotFormula::Formula(formula) = formula_fn.to_ensure_annot() else { - panic!("refinement_path annotation must lower to a plain formula"); - }; - out.push((position, formula.into())); + out.push((position, formula_fn.to_refinement())); } out } diff --git a/src/analyze/annot_fn.rs b/src/analyze/annot_fn.rs index 5503382f..b5399397 100644 --- a/src/analyze/annot_fn.rs +++ b/src/analyze/annot_fn.rs @@ -49,13 +49,33 @@ impl<'tcx> FormulaFn<'tcx> { } pub fn to_ensure_annot(&self) -> AnnotFormula> { - AnnotFormula::Formula(self.formula.clone().map_var(|v| { + AnnotFormula::Formula(self.formula_with_value_binder()) + } + + /// Lowers this formula function into a [`rty::Refinement`] on the enclosing + /// function's parameters. + /// + /// The formula function's first parameter is treated as the bound value + /// ([`rty::RefinedTypeVar::Value`]); the remaining parameters are the free + /// variables, mapped to [`rty::RefinedTypeVar::Free`] of the enclosing + /// function's parameters. + pub fn to_refinement(&self) -> rty::Refinement { + self.formula_with_value_binder().into() + } + + /// Re-maps this formula function's parameters as a refinement: + /// param 0 → [`rty::RefinedTypeVar::Value`], param `i > 0` → free variable + /// `i - 1` of the enclosing function. + fn formula_with_value_binder( + &self, + ) -> chc::Formula> { + self.formula.clone().map_var(|v| { if v.as_usize() == 0 { rty::RefinedTypeVar::Value } else { rty::RefinedTypeVar::Free(rty::FunctionParamIdx::from(v.as_usize() - 1)) } - })) + }) } } diff --git a/src/analyze/local_def.rs b/src/analyze/local_def.rs index b3472cb6..c8c700a6 100644 --- a/src/analyze/local_def.rs +++ b/src/analyze/local_def.rs @@ -436,7 +436,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { builder.ret_rty(ret_rty); } for (position, refinement) in refinement_annots { - builder.install_refinement_at(&position, refinement); + builder.refinement_at(&position, refinement); } if is_fully_annotated { diff --git a/src/refine/template.rs b/src/refine/template.rs index 37956784..fe200bda 100644 --- a/src/refine/template.rs +++ b/src/refine/template.rs @@ -566,12 +566,12 @@ impl<'tcx, 'a, R> FunctionTemplateTypeBuilder<'tcx, 'a, R> { self } - /// Installs a refinement at a [`rty::TypePosition`]. + /// Records a refinement to install at a [`rty::TypePosition`]. /// /// The first step must be [`rty::TypePositionStep::Param`] or /// [`rty::TypePositionStep::Return`]; the remaining steps are forwarded to /// [`rty::RefinedType::install_refinement_at`]. - pub fn install_refinement_at( + pub fn refinement_at( &mut self, position: &rty::TypePosition, refinement: rty::Refinement, From 076831d9cca9be45736f5609b4da3045830d421b Mon Sep 17 00:00:00 2001 From: Claude Date: Sat, 30 May 2026 05:15:08 +0000 Subject: [PATCH 10/18] Test refinement_path on a generic ADT's element type Add pass/fail pairs for both the param and sig macro forms, using a generic Pair enum monomorphized to Pair in the signature and a refinement annotation that supplies the element type explicitly as Pair<{ v: i32 | v > 0 }>. Exercises the TypeArg navigation step into enum type arguments. https://claude.ai/code/session_01Km9xwaVaGQjnAy1iHsPewa --- tests/ui/fail/refine_param_generic_adt.rs | 19 +++++++++++++++++++ tests/ui/fail/refine_sig_generic_adt.rs | 18 ++++++++++++++++++ tests/ui/pass/refine_param_generic_adt.rs | 19 +++++++++++++++++++ tests/ui/pass/refine_sig_generic_adt.rs | 18 ++++++++++++++++++ 4 files changed, 74 insertions(+) create mode 100644 tests/ui/fail/refine_param_generic_adt.rs create mode 100644 tests/ui/fail/refine_sig_generic_adt.rs create mode 100644 tests/ui/pass/refine_param_generic_adt.rs create mode 100644 tests/ui/pass/refine_sig_generic_adt.rs diff --git a/tests/ui/fail/refine_param_generic_adt.rs b/tests/ui/fail/refine_param_generic_adt.rs new file mode 100644 index 00000000..9c7673e1 --- /dev/null +++ b/tests/ui/fail/refine_param_generic_adt.rs @@ -0,0 +1,19 @@ +//@error-in-other-file: Unsat + +pub enum Pair { + Mk(T, T), +} + +impl thrust_models::Model for Pair { + type Ty = Self; +} + +#[thrust_macros::param(p: Pair<{ v: i32 | v > 0 }>)] +#[thrust_macros::ret({ r: i32 | r > 100 })] +fn first(p: Pair) -> i32 { + match p { + Pair::Mk(a, _) => a, + } +} + +fn main() {} diff --git a/tests/ui/fail/refine_sig_generic_adt.rs b/tests/ui/fail/refine_sig_generic_adt.rs new file mode 100644 index 00000000..f038f744 --- /dev/null +++ b/tests/ui/fail/refine_sig_generic_adt.rs @@ -0,0 +1,18 @@ +//@error-in-other-file: Unsat + +pub enum Pair { + Mk(T, T), +} + +impl thrust_models::Model for Pair { + type Ty = Self; +} + +#[thrust_macros::sig(fn(p: Pair<{ v: i32 | v > 0 }>) -> { r: i32 | r > 100 })] +fn second(p: Pair) -> i32 { + match p { + Pair::Mk(_, b) => b, + } +} + +fn main() {} diff --git a/tests/ui/pass/refine_param_generic_adt.rs b/tests/ui/pass/refine_param_generic_adt.rs new file mode 100644 index 00000000..73f20579 --- /dev/null +++ b/tests/ui/pass/refine_param_generic_adt.rs @@ -0,0 +1,19 @@ +//@check-pass + +pub enum Pair { + Mk(T, T), +} + +impl thrust_models::Model for Pair { + type Ty = Self; +} + +#[thrust_macros::param(p: Pair<{ v: i32 | v > 0 }>)] +#[thrust_macros::ret({ r: i32 | r > 0 })] +fn first(p: Pair) -> i32 { + match p { + Pair::Mk(a, _) => a, + } +} + +fn main() {} diff --git a/tests/ui/pass/refine_sig_generic_adt.rs b/tests/ui/pass/refine_sig_generic_adt.rs new file mode 100644 index 00000000..9e1d7f39 --- /dev/null +++ b/tests/ui/pass/refine_sig_generic_adt.rs @@ -0,0 +1,18 @@ +//@check-pass + +pub enum Pair { + Mk(T, T), +} + +impl thrust_models::Model for Pair { + type Ty = Self; +} + +#[thrust_macros::sig(fn(p: Pair<{ v: i32 | v > 0 }>) -> { r: i32 | r > 0 })] +fn second(p: Pair) -> i32 { + match p { + Pair::Mk(_, b) => b, + } +} + +fn main() {} From 19b8514915dbc931b07b1ea63977aafb24bb5afc Mon Sep 17 00:00:00 2001 From: Claude Date: Sat, 30 May 2026 05:38:45 +0000 Subject: [PATCH 11/18] Dedupe TypePositionStep docs, fix leftover [i] syntax, separate annot conventions - Centralize per-variant explanations on rty::TypePositionStep itself and reduce the duplicated copies in TypePosition, install_refinement_at, the macro's TypePositionStep, and parse_type_position to thin pointers back to the canonical source. - Fix "Box type position must be [0]" assert message left over from the earlier bracket-encoded syntax. - Drop FormulaFn::formula_with_value_binder. The 0->Value, i>0->Free(i-1) mapping is justified by the layout each macro generates (ensures puts the result at index 0; the refinement-type macros put the binder at index 0), not by a shared abstraction. Inline the mapping into to_ensure_annot and to_refinement, with each method's doc stating which macro's layout it relies on. https://claude.ai/code/session_01Km9xwaVaGQjnAy1iHsPewa --- src/analyze/annot.rs | 14 +++-------- src/analyze/annot_fn.rs | 52 +++++++++++++++++++++++----------------- src/rty.rs | 21 ++++++---------- thrust-macros/src/rty.rs | 11 ++++----- 4 files changed, 44 insertions(+), 54 deletions(-) diff --git a/src/analyze/annot.rs b/src/analyze/annot.rs index c278e9cb..ee65d06f 100644 --- a/src/analyze/annot.rs +++ b/src/analyze/annot.rs @@ -222,17 +222,9 @@ pub fn extract_annot_tokens(attr: Attribute) -> TokenStream { /// Parses a [`rty::TypePosition`] from the tokens of a /// `#[thrust::refinement_path(..)]` attribute. /// -/// Tokens are comma-separated steps. Each step is one of: -/// - The keyword `result` → [`rty::TypePositionStep::Return`] (navigate to a -/// function type's return slot). -/// - `$i` (a `$` followed by an integer) → [`rty::TypePositionStep::Param`]`(i)` -/// (navigate to the `i`-th parameter of a function type). -/// - A bare integer `i` → [`rty::TypePositionStep::TypeArg`]`(i)` (navigate to -/// the `i`-th type argument of a generic type such as an enum or `Box`). -/// -/// Examples: `result` is the return; `$0` is the first parameter; `$0, 0` is -/// the first type-argument of the first parameter; `$0, result` is the return -/// of a function-typed first parameter. +/// Tokens are comma-separated [`rty::TypePositionStep`]s, each encoded as +/// `result` (→ `Return`), `$i` (→ `Param(i)`), or a bare integer `i` (→ +/// `TypeArg(i)`). pub fn parse_type_position(ts: &TokenStream) -> rty::TypePosition { use rustc_ast::token::{LitKind, TokenKind}; use rustc_ast::tokenstream::TokenTree; diff --git a/src/analyze/annot_fn.rs b/src/analyze/annot_fn.rs index b5399397..0f8000f8 100644 --- a/src/analyze/annot_fn.rs +++ b/src/analyze/annot_fn.rs @@ -48,34 +48,42 @@ impl<'tcx> FormulaFn<'tcx> { AnnotFormula::Formula(self.formula.clone()) } - pub fn to_ensure_annot(&self) -> AnnotFormula> { - AnnotFormula::Formula(self.formula_with_value_binder()) - } - - /// Lowers this formula function into a [`rty::Refinement`] on the enclosing - /// function's parameters. + /// Lowers an `ensures` formula function into a postcondition annotation. /// - /// The formula function's first parameter is treated as the bound value - /// ([`rty::RefinedTypeVar::Value`]); the remaining parameters are the free - /// variables, mapped to [`rty::RefinedTypeVar::Free`] of the enclosing - /// function's parameters. - pub fn to_refinement(&self) -> rty::Refinement { - self.formula_with_value_binder().into() - } - - /// Re-maps this formula function's parameters as a refinement: - /// param 0 → [`rty::RefinedTypeVar::Value`], param `i > 0` → free variable - /// `i - 1` of the enclosing function. - fn formula_with_value_binder( - &self, - ) -> chc::Formula> { - self.formula.clone().map_var(|v| { + /// Relies on the layout produced by `thrust_macros::ensures`: parameter `0` + /// is the function's return value (bound to [`rty::RefinedTypeVar::Value`]) + /// and parameters `1..n` are the enclosing function's parameters in order + /// (mapped to [`rty::RefinedTypeVar::Free`]). + pub fn to_ensure_annot(&self) -> AnnotFormula> { + AnnotFormula::Formula(self.formula.clone().map_var(|v| { if v.as_usize() == 0 { rty::RefinedTypeVar::Value } else { rty::RefinedTypeVar::Free(rty::FunctionParamIdx::from(v.as_usize() - 1)) } - }) + })) + } + + /// Lowers a refinement-type formula function (generated by the `param` / + /// `ret` / `sig` macros) into a [`rty::Refinement`] on the enclosing + /// function's parameters. + /// + /// Relies on the layout produced by those macros: parameter `0` is the + /// refinement's value binder (bound to [`rty::RefinedTypeVar::Value`] at + /// the type position where the refinement is installed) and parameters + /// `1..n` are the enclosing function's parameters in order (mapped to + /// [`rty::RefinedTypeVar::Free`]). + pub fn to_refinement(&self) -> rty::Refinement { + self.formula + .clone() + .map_var(|v| { + if v.as_usize() == 0 { + rty::RefinedTypeVar::Value + } else { + rty::RefinedTypeVar::Free(rty::FunctionParamIdx::from(v.as_usize() - 1)) + } + }) + .into() } } diff --git a/src/rty.rs b/src/rty.rs index abfff7bf..558fe4e2 100644 --- a/src/rty.rs +++ b/src/rty.rs @@ -117,14 +117,11 @@ impl std::fmt::Display for TypePositionStep { } } -/// A path addressing a sub-type within a type, used to attach a refinement. +/// A path of [`TypePositionStep`]s addressing a sub-type within a type, used +/// to attach a refinement. /// -/// An empty path addresses the type itself. Each step descends one level: -/// [`TypePositionStep::Param`] / [`TypePositionStep::Return`] enter a function -/// type's parameter or return slot, and [`TypePositionStep::TypeArg`] enters a -/// generic type argument. Steps combine freely, so positions inside -/// higher-order function types are expressible. A path applied to a function -/// type is therefore non-empty, beginning with a `Param`/`Return` step. +/// An empty path addresses the type itself; a path applied to a function type +/// is therefore non-empty, beginning with a `Param` or `Return` step. /// /// Examples (function `fn f(x: List) -> Box`): /// - `$0` — parameter `x`. @@ -1555,12 +1552,8 @@ where impl RefinedType { /// Installs `refinement` at the sub-type addressed by `steps`. /// - /// An empty `steps` slice replaces the refinement at this node. Each step - /// in the slice navigates one level deeper: - /// - [`TypePositionStep::TypeArg`] descends into enum type arguments or the - /// `Box` pointee. - /// - [`TypePositionStep::Param`] / [`TypePositionStep::Return`] descend - /// into a function-typed position's parameter or return slot. + /// An empty `steps` slice replaces the refinement at this node; otherwise + /// each step navigates one level deeper per [`TypePositionStep`]. pub fn install_refinement_at( &mut self, steps: &[TypePositionStep], @@ -1579,7 +1572,7 @@ impl RefinedType { arg.install_refinement_at(rest, refinement); } Type::Pointer(p) => { - assert_eq!(*i, 0, "Box type position must be [0]"); + assert_eq!(*i, 0, "Box type position must be 0"); p.elem.install_refinement_at(rest, refinement); } ty => panic!("TypeArg step on unsupported type: {:?}", ty), diff --git a/thrust-macros/src/rty.rs b/thrust-macros/src/rty.rs index e32031e7..5ce1b3a8 100644 --- a/thrust-macros/src/rty.rs +++ b/thrust-macros/src/rty.rs @@ -19,14 +19,11 @@ use super::{ generic_turbofish, model_where_predicates, FnItemWithSignature, FnOuterItem, }; -/// One step in a refinement's type-position path. +/// One step in a refinement's type-position path; mirrors the plugin's +/// `rty::TypePositionStep`. /// -/// Mirrors the plugin's `rty::TypePositionStep` and uses the same attribute -/// encoding: -/// - [`Param`](Self::Param) / [`Return`](Self::Return) navigate into a function -/// type; encoded as `$i` / the `result` keyword. -/// - [`TypeArg`](Self::TypeArg) navigates into a generic type argument; encoded -/// as a bare integer `i`. +/// The attribute encoding emitted into `#[thrust::refinement_path(..)]` is: +/// `Param(i)` → `$i`, `Return` → `result`, `TypeArg(i)` → a bare integer `i`. #[derive(Clone, Copy)] enum TypePositionStep { Param(usize), From 37262c9b35c161b8d82517a7f00bede5b73257e5 Mon Sep 17 00:00:00 2001 From: Claude Date: Tue, 2 Jun 2026 15:12:11 +0000 Subject: [PATCH 12/18] Loosen to_refinement doc to cover non-direct positions The previous wording claimed "parameters 1..n are the enclosing function's parameters in order", but that framing only fits direct parameter/return refinement; a refinement_path can also be installed at, say, an ADT type-argument position. Describe parameters 1..n as "the formula's free variables" without overspecifying their meaning. https://claude.ai/code/session_01Km9xwaVaGQjnAy1iHsPewa --- src/analyze/annot_fn.rs | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/analyze/annot_fn.rs b/src/analyze/annot_fn.rs index 0f8000f8..e6157a7e 100644 --- a/src/analyze/annot_fn.rs +++ b/src/analyze/annot_fn.rs @@ -65,14 +65,12 @@ impl<'tcx> FormulaFn<'tcx> { } /// Lowers a refinement-type formula function (generated by the `param` / - /// `ret` / `sig` macros) into a [`rty::Refinement`] on the enclosing - /// function's parameters. + /// `ret` / `sig` macros) into a [`rty::Refinement`]. /// /// Relies on the layout produced by those macros: parameter `0` is the - /// refinement's value binder (bound to [`rty::RefinedTypeVar::Value`] at - /// the type position where the refinement is installed) and parameters - /// `1..n` are the enclosing function's parameters in order (mapped to - /// [`rty::RefinedTypeVar::Free`]). + /// refinement's value binder (mapped to [`rty::RefinedTypeVar::Value`]) + /// and the remaining parameters carry the formula's free variables, + /// mapped to [`rty::RefinedTypeVar::Free`]. pub fn to_refinement(&self) -> rty::Refinement { self.formula .clone() From 415bbdd9e9065d77aaff793b0ca0f62fc39a3ab2 Mon Sep 17 00:00:00 2001 From: Claude Date: Tue, 2 Jun 2026 15:26:58 +0000 Subject: [PATCH 13/18] Split refinement-type macro into three concrete expanders Replace rty::expand + AnnotationKind dispatch with three separate proc-macro implementations (expand_param, expand_ret, expand_sig). Each parses its own attribute form (which differs fundamentally) and collects refinements via a kind-specific collector (collect_param_refinements / collect_ret_refinements / collect_sig_refinements). The only shared step is the final output assembly, extracted as expand_with_refinements with a clear concrete contract. Drop the ad-hoc aggregates that just named tuples: PositionedTypeExpr, NamedType, SigAnnotation, and AnnotationKind. The parsers now use plain tuples destructured at the call site (parse_name_typed_binding -> (name, ty_tokens), parse_refinement -> (binder, binder_ty, formula)) with names introduced at the point they have local meaning. Also rename codegen helpers to describe what they produce: refine_fn_name -> formula_fn_name, refine_formula_fn -> build_formula_fn, refine_path_stmt -> build_refinement_path_stmt. Document Refinement's fields so binder_ty / formula are no longer opaquely "tokens". https://claude.ai/code/session_01Km9xwaVaGQjnAy1iHsPewa --- thrust-macros/src/lib.rs | 6 +- thrust-macros/src/rty.rs | 363 +++++++++++++++++++++------------------ 2 files changed, 201 insertions(+), 168 deletions(-) diff --git a/thrust-macros/src/lib.rs b/thrust-macros/src/lib.rs index 6fb0bbae..5d527b5a 100644 --- a/thrust-macros/src/lib.rs +++ b/thrust-macros/src/lib.rs @@ -574,17 +574,17 @@ impl ExpandedTokens { #[proc_macro_attribute] pub fn param(attr: TokenStream, item: TokenStream) -> TokenStream { - rty::expand(rty::AnnotationKind::Param, attr, item) + rty::expand_param(attr, item) } #[proc_macro_attribute] pub fn ret(attr: TokenStream, item: TokenStream) -> TokenStream { - rty::expand(rty::AnnotationKind::Ret, attr, item) + rty::expand_ret(attr, item) } #[proc_macro_attribute] pub fn sig(attr: TokenStream, item: TokenStream) -> TokenStream { - rty::expand(rty::AnnotationKind::Sig, attr, item) + rty::expand_sig(attr, item) } fn mentions_self(sig: &syn::Signature) -> bool { diff --git a/thrust-macros/src/rty.rs b/thrust-macros/src/rty.rs index 5ce1b3a8..af2223b1 100644 --- a/thrust-macros/src/rty.rs +++ b/thrust-macros/src/rty.rs @@ -31,67 +31,162 @@ enum TypePositionStep { TypeArg(usize), } +/// A refinement node `{ binder : ty | formula }` extracted from a type +/// expression, together with its position within the function signature. #[derive(Clone)] struct Refinement { - /// Full type-position path from the function root to the refined type. + /// Path from the function root to the refined sub-type. steps: Vec, + /// The bound name (the `v` in `{ v: T | phi }`). binder: syn::Ident, + /// Tokens of the binder's type (`T` in `{ v: T | phi }`); kept as tokens + /// because nested refinement braces make it not always a valid Rust type. binder_ty: TokenStream2, + /// Tokens of the refinement formula expression (`phi` in `{ v: T | phi }`). formula: TokenStream2, } -/// Which refinement-type annotation is being expanded. -pub(crate) enum AnnotationKind { - Param, - Ret, - Sig, -} +// --------------------------------------------------------------------------- +// Macro entry points (called by lib.rs). +// --------------------------------------------------------------------------- -/// A type expression from the annotation, paired with the position of its root -/// within the function signature. [`scan_type`] walks each one to extract the -/// refinements it contains. -struct PositionedTypeExpr { - /// Steps locating the root of `tokens` (e.g. `[Param(0)]` for the first - /// parameter); [`scan_type`] appends `TypeArg` steps as it descends. - root: Vec, - tokens: Vec, +pub(crate) fn expand_param(attr: TokenStream, item: TokenStream) -> TokenStream { + let func = parse_macro_input!(item as FnItemWithSignature); + let outer_context = match extract_outer_context(&func) { + Ok(ctx) => ctx, + Err(e) => return emit_error(e, &func), + }; + let attr_tokens: Vec = TokenStream2::from(attr).into_iter().collect(); + let refinements = match collect_param_refinements(&func, &attr_tokens) { + Ok(r) => r, + Err(e) => return emit_error(e, &func), + }; + expand_with_refinements(func, outer_context, refinements) } -/// A `name : type` binding, e.g. a parameter in a `sig` annotation or the -/// binder of a refinement `{ name: type | .. }`. -struct NamedType { - name: syn::Ident, - tokens: Vec, +pub(crate) fn expand_ret(attr: TokenStream, item: TokenStream) -> TokenStream { + let func = parse_macro_input!(item as FnItemWithSignature); + let outer_context = match extract_outer_context(&func) { + Ok(ctx) => ctx, + Err(e) => return emit_error(e, &func), + }; + let attr_tokens: Vec = TokenStream2::from(attr).into_iter().collect(); + let refinements = match collect_ret_refinements(&attr_tokens) { + Ok(r) => r, + Err(e) => return emit_error(e, &func), + }; + expand_with_refinements(func, outer_context, refinements) } -pub(crate) fn expand(kind: AnnotationKind, attr: TokenStream, item: TokenStream) -> TokenStream { - let mut func = parse_macro_input!(item as FnItemWithSignature); - +pub(crate) fn expand_sig(attr: TokenStream, item: TokenStream) -> TokenStream { + let func = parse_macro_input!(item as FnItemWithSignature); let outer_context = match extract_outer_context(&func) { Ok(ctx) => ctx, - Err(e) => { - let err = e.to_compile_error(); - return quote! { #err #func }.into(); - } + Err(e) => return emit_error(e, &func), }; - let attr_tokens: Vec = TokenStream2::from(attr).into_iter().collect(); - let type_exprs = match annotated_type_exprs(kind, &func, &attr_tokens) { - Ok(exprs) => exprs, - Err(e) => { - let err = e.to_compile_error(); - return quote! { #err #func }.into(); - } + let refinements = match collect_sig_refinements(&func, &attr_tokens) { + Ok(r) => r, + Err(e) => return emit_error(e, &func), + }; + expand_with_refinements(func, outer_context, refinements) +} + +// --------------------------------------------------------------------------- +// Per-annotation refinement collectors. +// +// Each parses one specific attribute form and extracts the refinement nodes +// contained in its type expression(s). +// --------------------------------------------------------------------------- + +/// Parses a `#[param(name: ty)]` attribute and extracts the refinements in +/// `ty`, anchored at the position of the parameter `name`. +fn collect_param_refinements( + func: &FnItemWithSignature, + attr_tokens: &[TokenTree2], +) -> syn::Result> { + let (name, ty_tokens) = parse_name_typed_binding(attr_tokens)?; + let idx = param_index(func, &name)?; + let mut refinements = Vec::new(); + scan_type( + &ty_tokens, + vec![TypePositionStep::Param(idx)], + &mut refinements, + )?; + Ok(refinements) +} + +/// Parses a `#[ret(ty)]` attribute and extracts the refinements in `ty`, +/// anchored at the return position. +fn collect_ret_refinements(attr_tokens: &[TokenTree2]) -> syn::Result> { + let mut refinements = Vec::new(); + scan_type( + attr_tokens, + vec![TypePositionStep::Return], + &mut refinements, + )?; + Ok(refinements) +} + +/// Parses a `#[sig(fn(n0: t0, ..) -> r)]` attribute and extracts the +/// refinements in every parameter type and the return type, each anchored at +/// the corresponding position. +fn collect_sig_refinements( + func: &FnItemWithSignature, + attr_tokens: &[TokenTree2], +) -> syn::Result> { + match attr_tokens.first() { + Some(TokenTree2::Ident(id)) if id == "fn" => {} + _ => return Err(err_tokens(attr_tokens, "expected `fn` in sig annotation")), + } + let arg_group = match attr_tokens.get(1) { + Some(TokenTree2::Group(g)) if g.delimiter() == proc_macro2::Delimiter::Parenthesis => g, + _ => return Err(err_tokens(attr_tokens, "expected `(..)` after `fn`")), }; let mut refinements = Vec::new(); - for expr in type_exprs { - if let Err(e) = scan_type(&expr.tokens, expr.root, &mut refinements) { - let err = e.to_compile_error(); - return quote! { #err #func }.into(); + let arg_tokens: Vec = arg_group.stream().into_iter().collect(); + for arg in split_top_level_commas(&arg_tokens) { + if arg.is_empty() { + continue; } + let (name, ty_tokens) = parse_name_typed_binding(&arg)?; + let idx = param_index(func, &name)?; + scan_type( + &ty_tokens, + vec![TypePositionStep::Param(idx)], + &mut refinements, + )?; } + let rest = &attr_tokens[2..]; + let ret_tokens = match (rest.first(), rest.get(1)) { + (Some(TokenTree2::Punct(a)), Some(TokenTree2::Punct(b))) + if a.as_char() == '-' && b.as_char() == '>' => + { + &rest[2..] + } + _ => { + return Err(err_tokens( + attr_tokens, + "expected `->` and a return type in sig annotation", + )) + } + }; + scan_type(ret_tokens, vec![TypePositionStep::Return], &mut refinements)?; + Ok(refinements) +} + +// --------------------------------------------------------------------------- +// Macro-output assembly: from the collected refinements, produce the +// formula_fn declarations and the path-statement-injected function. +// --------------------------------------------------------------------------- + +fn expand_with_refinements( + mut func: FnItemWithSignature, + outer_context: Option, + refinements: Vec, +) -> TokenStream { if refinements.is_empty() { return func.into_token_stream().into(); } @@ -103,8 +198,8 @@ pub(crate) fn expand(kind: AnnotationKind, attr: TokenStream, item: TokenStream) if has_receiver { r.formula = rewrite_self_in_tokens(r.formula); } - formula_fns.push(refine_formula_fn(&func, outer_context.as_ref(), &r)); - path_stmts.push(refine_path_stmt(&func, &r)); + formula_fns.push(build_formula_fn(&func, outer_context.as_ref(), &r)); + path_stmts.push(build_refinement_path_stmt(&func, &r)); } let Some(block) = func.block_mut() else { @@ -130,121 +225,19 @@ pub(crate) fn expand(kind: AnnotationKind, attr: TokenStream, item: TokenStream) .into() } -/// Turns an annotation into the type expressions to scan, each anchored at its -/// root position within the function signature. -fn annotated_type_exprs( - kind: AnnotationKind, - func: &FnItemWithSignature, - attr_tokens: &[TokenTree2], -) -> syn::Result> { - let at_param = |func: &FnItemWithSignature, nt: NamedType| -> syn::Result { - let idx = param_index(func, &nt.name)?; - Ok(PositionedTypeExpr { - root: vec![TypePositionStep::Param(idx)], - tokens: nt.tokens, - }) - }; - match kind { - AnnotationKind::Param => Ok(vec![at_param(func, split_name_type(attr_tokens)?)?]), - AnnotationKind::Ret => Ok(vec![PositionedTypeExpr { - root: vec![TypePositionStep::Return], - tokens: attr_tokens.to_vec(), - }]), - AnnotationKind::Sig => { - let sig = parse_sig_attr(attr_tokens)?; - let mut exprs = Vec::new(); - for param in sig.params { - exprs.push(at_param(func, param)?); - } - exprs.push(PositionedTypeExpr { - root: vec![TypePositionStep::Return], - tokens: sig.ret, - }); - Ok(exprs) - } - } -} - -fn param_index(func: &FnItemWithSignature, name: &syn::Ident) -> syn::Result { - let pos = func.sig().inputs.iter().position(|arg| match arg { - FnArg::Receiver(_) => name == "self", - FnArg::Typed(pt) => matches!(&*pt.pat, syn::Pat::Ident(pi) if &pi.ident == name), - }); - pos.ok_or_else(|| { - syn::Error::new_spanned(name, format!("no parameter named `{}` in signature", name)) - }) -} - -/// Parses `name : ` from a flat token slice. -fn split_name_type(tokens: &[TokenTree2]) -> syn::Result { - let name = match tokens.first() { - Some(TokenTree2::Ident(id)) => id.clone(), - _ => return Err(err_tokens(tokens, "expected a parameter name")), - }; - match tokens.get(1) { - Some(TokenTree2::Punct(p)) if p.as_char() == ':' => {} - _ => return Err(err_tokens(tokens, "expected `:` after parameter name")), - } - Ok(NamedType { - name, - tokens: tokens[2..].to_vec(), - }) -} - -/// The parsed parts of a `fn ( n0: t0 , ... ) -> ret` signature annotation. -struct SigAnnotation { - params: Vec, - ret: Vec, +fn emit_error(err: syn::Error, func: &FnItemWithSignature) -> TokenStream { + let err = err.to_compile_error(); + quote! { #err #func }.into() } -/// Parses `fn ( n0: t0 , ... ) -> ret`. -fn parse_sig_attr(tokens: &[TokenTree2]) -> syn::Result { - match tokens.first() { - Some(TokenTree2::Ident(id)) if id == "fn" => {} - _ => return Err(err_tokens(tokens, "expected `fn` in sig annotation")), - } - let arg_group = match tokens.get(1) { - Some(TokenTree2::Group(g)) if g.delimiter() == proc_macro2::Delimiter::Parenthesis => g, - _ => return Err(err_tokens(tokens, "expected `(..)` after `fn`")), - }; - - let mut params = Vec::new(); - let arg_tokens: Vec = arg_group.stream().into_iter().collect(); - for arg in split_top_level_commas(&arg_tokens) { - if arg.is_empty() { - continue; - } - params.push(split_name_type(&arg)?); - } +// --------------------------------------------------------------------------- +// Type-expression scanner: walks a type expression in `name: ty` / `ty` form, +// recording every refinement node and the path that locates it. +// --------------------------------------------------------------------------- - // expect `->` then the return type - let mut rest = &tokens[2..]; - match (rest.first(), rest.get(1)) { - (Some(TokenTree2::Punct(a)), Some(TokenTree2::Punct(b))) - if a.as_char() == '-' && b.as_char() == '>' => - { - rest = &rest[2..]; - } - _ => { - return Err(err_tokens( - tokens, - "expected `->` and a return type in sig annotation", - )) - } - } - Ok(SigAnnotation { - params, - ret: rest.to_vec(), - }) -} - -/// Scans a type expression and records every refinement node with its full -/// type-position path (`steps`). -/// -/// `steps` holds the path from the function root to the current type node. -/// When a refinement `{binder: ty | formula}` is found the current `steps` are -/// recorded; when descending into generic type arguments a -/// [`TypePositionStep::TypeArg`]`(i)` step is appended to `steps`. +/// Walks the type expression `tokens` and appends any refinement nodes it +/// contains to `out`. `steps` is the path to `tokens` itself; descending into a +/// generic type argument extends it with a [`TypePositionStep::TypeArg`]. fn scan_type( tokens: &[TokenTree2], steps: Vec, @@ -258,15 +251,15 @@ fn scan_type( if tokens.len() == 1 { if let TokenTree2::Group(g) = &tokens[0] { if g.delimiter() == proc_macro2::Delimiter::Brace { - let (binder, formula) = split_refinement(g.stream())?; + let (binder, binder_ty, formula) = parse_refinement(g.stream())?; out.push(Refinement { steps: steps.clone(), - binder: binder.name, - binder_ty: binder.tokens.iter().cloned().collect(), + binder, + binder_ty: binder_ty.iter().cloned().collect(), formula, }); // Descend into the binder type for nested refinements. - scan_type(&binder.tokens, steps, out)?; + scan_type(&binder_ty, steps, out)?; return Ok(()); } } @@ -293,16 +286,37 @@ fn scan_type( Ok(()) } -/// Splits `{ binder : ty | formula }` into its binder and formula expression. -fn split_refinement(stream: TokenStream2) -> syn::Result<(NamedType, TokenStream2)> { +// --------------------------------------------------------------------------- +// Token-level parsers. +// --------------------------------------------------------------------------- + +/// Parses `name : ` (a parameter binding or a refinement +/// binder). Returns the bound name and the remaining type tokens. +fn parse_name_typed_binding(tokens: &[TokenTree2]) -> syn::Result<(syn::Ident, Vec)> { + let name = match tokens.first() { + Some(TokenTree2::Ident(id)) => id.clone(), + _ => return Err(err_tokens(tokens, "expected a parameter name")), + }; + match tokens.get(1) { + Some(TokenTree2::Punct(p)) if p.as_char() == ':' => {} + _ => return Err(err_tokens(tokens, "expected `:` after parameter name")), + } + Ok((name, tokens[2..].to_vec())) +} + +/// Parses `{ binder : ty | formula }` into the binder name, the binder type +/// tokens, and the formula expression. +fn parse_refinement( + stream: TokenStream2, +) -> syn::Result<(syn::Ident, Vec, TokenStream2)> { let toks: Vec = stream.into_iter().collect(); let bar = toks .iter() .position(|tt| matches!(tt, TokenTree2::Punct(p) if p.as_char() == '|')) .ok_or_else(|| err_tokens(&toks, "refinement type must contain `|`"))?; - let binder = split_name_type(&toks[..bar])?; + let (binder, binder_ty) = parse_name_typed_binding(&toks[..bar])?; let formula: TokenStream2 = toks[bar + 1..].iter().cloned().collect(); - Ok((binder, formula)) + Ok((binder, binder_ty, formula)) } /// Splits the tokens following an opening `<` at top level by commas, stopping @@ -373,6 +387,20 @@ fn err_tokens(tokens: &[TokenTree2], msg: &str) -> syn::Error { syn::Error::new_spanned(stream, msg) } +fn param_index(func: &FnItemWithSignature, name: &syn::Ident) -> syn::Result { + let pos = func.sig().inputs.iter().position(|arg| match arg { + FnArg::Receiver(_) => name == "self", + FnArg::Typed(pt) => matches!(&*pt.pat, syn::Pat::Ident(pi) if &pi.ident == name), + }); + pos.ok_or_else(|| { + syn::Error::new_spanned(name, format!("no parameter named `{}` in signature", name)) + }) +} + +// --------------------------------------------------------------------------- +// Codegen helpers for a single [`Refinement`]. +// --------------------------------------------------------------------------- + fn rewrite_self_in_tokens(tokens: TokenStream2) -> TokenStream2 { tokens .into_iter() @@ -387,7 +415,7 @@ fn rewrite_self_in_tokens(tokens: TokenStream2) -> TokenStream2 { .collect() } -fn refine_fn_name(func: &FnItemWithSignature, r: &Refinement) -> syn::Ident { +fn formula_fn_name(func: &FnItemWithSignature, r: &Refinement) -> syn::Ident { let pos = r .steps .iter() @@ -401,12 +429,15 @@ fn refine_fn_name(func: &FnItemWithSignature, r: &Refinement) -> syn::Ident { format_ident!("_thrust_refine_{}_{}", func.sig().ident, pos) } -fn refine_formula_fn( +/// Builds the `#[thrust::formula_fn]` declaration that carries the refinement's +/// formula. Its first parameter is the refinement's value binder; the rest are +/// the enclosing function's parameters in model-typed form. +fn build_formula_fn( func: &FnItemWithSignature, outer_context: Option<&FnOuterItem>, r: &Refinement, ) -> TokenStream2 { - let name = refine_fn_name(func, r); + let name = formula_fn_name(func, r); let def_generics = generic_params_tokens(&func.sig().generics); let model_params = fn_params_with_model_ty(&func.sig().inputs); let model_preds = model_where_predicates(func, outer_context); @@ -428,8 +459,10 @@ fn refine_formula_fn( } } -fn refine_path_stmt(func: &FnItemWithSignature, r: &Refinement) -> TokenStream2 { - let name = refine_fn_name(func, r); +/// Builds the `#[thrust::refinement_path(..)] ;` path statement that links +/// the generated formula_fn to the refinement's [`TypePositionStep`] path. +fn build_refinement_path_stmt(func: &FnItemWithSignature, r: &Refinement) -> TokenStream2 { + let name = formula_fn_name(func, r); let turbofish = generic_turbofish(&func.sig().generics); let path_prefix = if func.sig().receiver().is_some() { quote!(Self::) From 71b65ba47942f51bea85cc95f51a6a817a074bff Mon Sep 17 00:00:00 2001 From: Claude Date: Tue, 2 Jun 2026 15:52:57 +0000 Subject: [PATCH 14/18] Split Refinement into RefinedType and RefinedTypeAnnotation - RefinedType models the `{ binder : ty | formula }` syntax fragment. - RefinedTypeAnnotation is a RefinedType paired with its position in the function signature (where it applies). Rename collect_*_refinements to collect_*_annotations and expand_with_refinements to expand_with_annotations so the function names line up with the type they handle. Update parse_refinement to parse_refined_type, returning a RefinedType directly. https://claude.ai/code/session_01Km9xwaVaGQjnAy1iHsPewa --- thrust-macros/src/rty.rs | 173 +++++++++++++++++++++------------------ 1 file changed, 93 insertions(+), 80 deletions(-) diff --git a/thrust-macros/src/rty.rs b/thrust-macros/src/rty.rs index af2223b1..3cdfbc94 100644 --- a/thrust-macros/src/rty.rs +++ b/thrust-macros/src/rty.rs @@ -31,12 +31,10 @@ enum TypePositionStep { TypeArg(usize), } -/// A refinement node `{ binder : ty | formula }` extracted from a type -/// expression, together with its position within the function signature. +/// A refinement type `{ binder : ty | formula }` parsed from a type +/// expression in a `param` / `ret` / `sig` annotation. #[derive(Clone)] -struct Refinement { - /// Path from the function root to the refined sub-type. - steps: Vec, +struct RefinedType { /// The bound name (the `v` in `{ v: T | phi }`). binder: syn::Ident, /// Tokens of the binder's type (`T` in `{ v: T | phi }`); kept as tokens @@ -46,6 +44,15 @@ struct Refinement { formula: TokenStream2, } +/// A [`RefinedType`] together with the position within the function signature +/// at which it applies. +#[derive(Clone)] +struct RefinedTypeAnnotation { + /// Path from the function root to the sub-type where `refined_type` applies. + position: Vec, + refined_type: RefinedType, +} + // --------------------------------------------------------------------------- // Macro entry points (called by lib.rs). // --------------------------------------------------------------------------- @@ -57,11 +64,11 @@ pub(crate) fn expand_param(attr: TokenStream, item: TokenStream) -> TokenStream Err(e) => return emit_error(e, &func), }; let attr_tokens: Vec = TokenStream2::from(attr).into_iter().collect(); - let refinements = match collect_param_refinements(&func, &attr_tokens) { - Ok(r) => r, + let annotations = match collect_param_annotations(&func, &attr_tokens) { + Ok(a) => a, Err(e) => return emit_error(e, &func), }; - expand_with_refinements(func, outer_context, refinements) + expand_with_annotations(func, outer_context, annotations) } pub(crate) fn expand_ret(attr: TokenStream, item: TokenStream) -> TokenStream { @@ -71,11 +78,11 @@ pub(crate) fn expand_ret(attr: TokenStream, item: TokenStream) -> TokenStream { Err(e) => return emit_error(e, &func), }; let attr_tokens: Vec = TokenStream2::from(attr).into_iter().collect(); - let refinements = match collect_ret_refinements(&attr_tokens) { - Ok(r) => r, + let annotations = match collect_ret_annotations(&attr_tokens) { + Ok(a) => a, Err(e) => return emit_error(e, &func), }; - expand_with_refinements(func, outer_context, refinements) + expand_with_annotations(func, outer_context, annotations) } pub(crate) fn expand_sig(attr: TokenStream, item: TokenStream) -> TokenStream { @@ -85,56 +92,56 @@ pub(crate) fn expand_sig(attr: TokenStream, item: TokenStream) -> TokenStream { Err(e) => return emit_error(e, &func), }; let attr_tokens: Vec = TokenStream2::from(attr).into_iter().collect(); - let refinements = match collect_sig_refinements(&func, &attr_tokens) { - Ok(r) => r, + let annotations = match collect_sig_annotations(&func, &attr_tokens) { + Ok(a) => a, Err(e) => return emit_error(e, &func), }; - expand_with_refinements(func, outer_context, refinements) + expand_with_annotations(func, outer_context, annotations) } // --------------------------------------------------------------------------- -// Per-annotation refinement collectors. +// Per-annotation collectors. // -// Each parses one specific attribute form and extracts the refinement nodes -// contained in its type expression(s). +// Each parses one specific attribute form and extracts the +// [`RefinedTypeAnnotation`]s contained in its type expression(s). // --------------------------------------------------------------------------- -/// Parses a `#[param(name: ty)]` attribute and extracts the refinements in -/// `ty`, anchored at the position of the parameter `name`. -fn collect_param_refinements( +/// Parses a `#[param(name: ty)]` attribute and extracts the refined-type +/// annotations from `ty`, anchored at the position of the parameter `name`. +fn collect_param_annotations( func: &FnItemWithSignature, attr_tokens: &[TokenTree2], -) -> syn::Result> { +) -> syn::Result> { let (name, ty_tokens) = parse_name_typed_binding(attr_tokens)?; let idx = param_index(func, &name)?; - let mut refinements = Vec::new(); + let mut annotations = Vec::new(); scan_type( &ty_tokens, vec![TypePositionStep::Param(idx)], - &mut refinements, + &mut annotations, )?; - Ok(refinements) + Ok(annotations) } -/// Parses a `#[ret(ty)]` attribute and extracts the refinements in `ty`, -/// anchored at the return position. -fn collect_ret_refinements(attr_tokens: &[TokenTree2]) -> syn::Result> { - let mut refinements = Vec::new(); +/// Parses a `#[ret(ty)]` attribute and extracts the refined-type annotations +/// from `ty`, anchored at the return position. +fn collect_ret_annotations(attr_tokens: &[TokenTree2]) -> syn::Result> { + let mut annotations = Vec::new(); scan_type( attr_tokens, vec![TypePositionStep::Return], - &mut refinements, + &mut annotations, )?; - Ok(refinements) + Ok(annotations) } /// Parses a `#[sig(fn(n0: t0, ..) -> r)]` attribute and extracts the -/// refinements in every parameter type and the return type, each anchored at -/// the corresponding position. -fn collect_sig_refinements( +/// refined-type annotations from every parameter type and the return type, +/// each anchored at the corresponding position. +fn collect_sig_annotations( func: &FnItemWithSignature, attr_tokens: &[TokenTree2], -) -> syn::Result> { +) -> syn::Result> { match attr_tokens.first() { Some(TokenTree2::Ident(id)) if id == "fn" => {} _ => return Err(err_tokens(attr_tokens, "expected `fn` in sig annotation")), @@ -144,7 +151,7 @@ fn collect_sig_refinements( _ => return Err(err_tokens(attr_tokens, "expected `(..)` after `fn`")), }; - let mut refinements = Vec::new(); + let mut annotations = Vec::new(); let arg_tokens: Vec = arg_group.stream().into_iter().collect(); for arg in split_top_level_commas(&arg_tokens) { if arg.is_empty() { @@ -155,7 +162,7 @@ fn collect_sig_refinements( scan_type( &ty_tokens, vec![TypePositionStep::Param(idx)], - &mut refinements, + &mut annotations, )?; } @@ -173,33 +180,34 @@ fn collect_sig_refinements( )) } }; - scan_type(ret_tokens, vec![TypePositionStep::Return], &mut refinements)?; - Ok(refinements) + scan_type(ret_tokens, vec![TypePositionStep::Return], &mut annotations)?; + Ok(annotations) } // --------------------------------------------------------------------------- -// Macro-output assembly: from the collected refinements, produce the +// Macro-output assembly: from the collected annotations, produce the // formula_fn declarations and the path-statement-injected function. // --------------------------------------------------------------------------- -fn expand_with_refinements( +fn expand_with_annotations( mut func: FnItemWithSignature, outer_context: Option, - refinements: Vec, + annotations: Vec, ) -> TokenStream { - if refinements.is_empty() { + if annotations.is_empty() { return func.into_token_stream().into(); } let has_receiver = func.sig().receiver().is_some(); let mut formula_fns = Vec::new(); let mut path_stmts = Vec::new(); - for mut r in refinements { + for mut annotation in annotations { if has_receiver { - r.formula = rewrite_self_in_tokens(r.formula); + annotation.refined_type.formula = + rewrite_self_in_tokens(annotation.refined_type.formula); } - formula_fns.push(build_formula_fn(&func, outer_context.as_ref(), &r)); - path_stmts.push(build_refinement_path_stmt(&func, &r)); + formula_fns.push(build_formula_fn(&func, outer_context.as_ref(), &annotation)); + path_stmts.push(build_refinement_path_stmt(&func, &annotation)); } let Some(block) = func.block_mut() else { @@ -235,31 +243,31 @@ fn emit_error(err: syn::Error, func: &FnItemWithSignature) -> TokenStream { // recording every refinement node and the path that locates it. // --------------------------------------------------------------------------- -/// Walks the type expression `tokens` and appends any refinement nodes it -/// contains to `out`. `steps` is the path to `tokens` itself; descending into a -/// generic type argument extends it with a [`TypePositionStep::TypeArg`]. +/// Walks the type expression `tokens` and appends any [`RefinedType`] nodes it +/// contains to `out` as [`RefinedTypeAnnotation`]s. `steps` is the path to +/// `tokens` itself; descending into a generic type argument extends it with a +/// [`TypePositionStep::TypeArg`]. fn scan_type( tokens: &[TokenTree2], steps: Vec, - out: &mut Vec, + out: &mut Vec, ) -> syn::Result<()> { if tokens.is_empty() { return Ok(()); } - // A refinement node is exactly a brace-delimited group. + // A refinement type is exactly a brace-delimited group. if tokens.len() == 1 { if let TokenTree2::Group(g) = &tokens[0] { if g.delimiter() == proc_macro2::Delimiter::Brace { - let (binder, binder_ty, formula) = parse_refinement(g.stream())?; - out.push(Refinement { - steps: steps.clone(), - binder, - binder_ty: binder_ty.iter().cloned().collect(), - formula, + let refined_type = parse_refined_type(g.stream())?; + let nested: Vec = refined_type.binder_ty.clone().into_iter().collect(); + out.push(RefinedTypeAnnotation { + position: steps.clone(), + refined_type, }); // Descend into the binder type for nested refinements. - scan_type(&binder_ty, steps, out)?; + scan_type(&nested, steps, out)?; return Ok(()); } } @@ -304,19 +312,21 @@ fn parse_name_typed_binding(tokens: &[TokenTree2]) -> syn::Result<(syn::Ident, V Ok((name, tokens[2..].to_vec())) } -/// Parses `{ binder : ty | formula }` into the binder name, the binder type -/// tokens, and the formula expression. -fn parse_refinement( - stream: TokenStream2, -) -> syn::Result<(syn::Ident, Vec, TokenStream2)> { +/// Parses the contents of a `{ binder : ty | formula }` brace group into a +/// [`RefinedType`]. +fn parse_refined_type(stream: TokenStream2) -> syn::Result { let toks: Vec = stream.into_iter().collect(); let bar = toks .iter() .position(|tt| matches!(tt, TokenTree2::Punct(p) if p.as_char() == '|')) .ok_or_else(|| err_tokens(&toks, "refinement type must contain `|`"))?; - let (binder, binder_ty) = parse_name_typed_binding(&toks[..bar])?; + let (binder, binder_ty_tokens) = parse_name_typed_binding(&toks[..bar])?; let formula: TokenStream2 = toks[bar + 1..].iter().cloned().collect(); - Ok((binder, binder_ty, formula)) + Ok(RefinedType { + binder, + binder_ty: binder_ty_tokens.iter().cloned().collect(), + formula, + }) } /// Splits the tokens following an opening `<` at top level by commas, stopping @@ -415,9 +425,9 @@ fn rewrite_self_in_tokens(tokens: TokenStream2) -> TokenStream2 { .collect() } -fn formula_fn_name(func: &FnItemWithSignature, r: &Refinement) -> syn::Ident { - let pos = r - .steps +fn formula_fn_name(func: &FnItemWithSignature, ann: &RefinedTypeAnnotation) -> syn::Ident { + let pos = ann + .position .iter() .map(|s| match s { TypePositionStep::Param(i) => format!("p{}", i), @@ -429,22 +439,22 @@ fn formula_fn_name(func: &FnItemWithSignature, r: &Refinement) -> syn::Ident { format_ident!("_thrust_refine_{}_{}", func.sig().ident, pos) } -/// Builds the `#[thrust::formula_fn]` declaration that carries the refinement's -/// formula. Its first parameter is the refinement's value binder; the rest are -/// the enclosing function's parameters in model-typed form. +/// Builds the `#[thrust::formula_fn]` declaration that carries the annotation's +/// formula. Its first parameter is the refined type's value binder; the rest +/// are the enclosing function's parameters in model-typed form. fn build_formula_fn( func: &FnItemWithSignature, outer_context: Option<&FnOuterItem>, - r: &Refinement, + ann: &RefinedTypeAnnotation, ) -> TokenStream2 { - let name = formula_fn_name(func, r); + let name = formula_fn_name(func, ann); let def_generics = generic_params_tokens(&func.sig().generics); let model_params = fn_params_with_model_ty(&func.sig().inputs); let model_preds = model_where_predicates(func, outer_context); let extended_where = extended_where_clause(func, &model_preds); - let binder = &r.binder; - let binder_ty = &r.binder_ty; - let formula = &r.formula; + let binder = &ann.refined_type.binder; + let binder_ty = &ann.refined_type.binder_ty; + let formula = &ann.refined_type.formula; quote! { #[allow(unused_variables)] @@ -460,16 +470,19 @@ fn build_formula_fn( } /// Builds the `#[thrust::refinement_path(..)] ;` path statement that links -/// the generated formula_fn to the refinement's [`TypePositionStep`] path. -fn build_refinement_path_stmt(func: &FnItemWithSignature, r: &Refinement) -> TokenStream2 { - let name = formula_fn_name(func, r); +/// the generated formula_fn to the annotation's position. +fn build_refinement_path_stmt( + func: &FnItemWithSignature, + ann: &RefinedTypeAnnotation, +) -> TokenStream2 { + let name = formula_fn_name(func, ann); let turbofish = generic_turbofish(&func.sig().generics); let path_prefix = if func.sig().receiver().is_some() { quote!(Self::) } else { quote!() }; - let encoded_steps = r.steps.iter().map(|s| match s { + let encoded_steps = ann.position.iter().map(|s| match s { TypePositionStep::Param(i) => { let lit = proc_macro2::Literal::usize_unsuffixed(*i); quote!($#lit) From 03336f2df912b3d9988d7d1b938596ac9f6b5e99 Mon Sep 17 00:00:00 2001 From: Claude Date: Tue, 2 Jun 2026 15:57:26 +0000 Subject: [PATCH 15/18] Replace scan_type with parse_refined_type_annotations Reframe the type-expression walker so its name and signature tell you what it does: it takes type-expression tokens and returns the refined-type annotations they contain, with positions relative to the input. No mutable accumulator, no implicit root-step parameter. Callers add the root step via a small anchor_at helper that prepends a TypePositionStep to every annotation's position; nested recursion inside parse_refined_type_annotations uses the same helper to attribute TypeArg steps for each generic argument. https://claude.ai/code/session_01Km9xwaVaGQjnAy1iHsPewa --- thrust-macros/src/rty.rs | 94 +++++++++++++++++++++------------------- 1 file changed, 50 insertions(+), 44 deletions(-) diff --git a/thrust-macros/src/rty.rs b/thrust-macros/src/rty.rs index 3cdfbc94..698ae7c9 100644 --- a/thrust-macros/src/rty.rs +++ b/thrust-macros/src/rty.rs @@ -114,25 +114,15 @@ fn collect_param_annotations( ) -> syn::Result> { let (name, ty_tokens) = parse_name_typed_binding(attr_tokens)?; let idx = param_index(func, &name)?; - let mut annotations = Vec::new(); - scan_type( - &ty_tokens, - vec![TypePositionStep::Param(idx)], - &mut annotations, - )?; - Ok(annotations) + let annotations = parse_refined_type_annotations(&ty_tokens)?; + Ok(anchor_at(annotations, TypePositionStep::Param(idx))) } /// Parses a `#[ret(ty)]` attribute and extracts the refined-type annotations /// from `ty`, anchored at the return position. fn collect_ret_annotations(attr_tokens: &[TokenTree2]) -> syn::Result> { - let mut annotations = Vec::new(); - scan_type( - attr_tokens, - vec![TypePositionStep::Return], - &mut annotations, - )?; - Ok(annotations) + let annotations = parse_refined_type_annotations(attr_tokens)?; + Ok(anchor_at(annotations, TypePositionStep::Return)) } /// Parses a `#[sig(fn(n0: t0, ..) -> r)]` attribute and extracts the @@ -159,11 +149,10 @@ fn collect_sig_annotations( } let (name, ty_tokens) = parse_name_typed_binding(&arg)?; let idx = param_index(func, &name)?; - scan_type( - &ty_tokens, - vec![TypePositionStep::Param(idx)], - &mut annotations, - )?; + annotations.extend(anchor_at( + parse_refined_type_annotations(&ty_tokens)?, + TypePositionStep::Param(idx), + )); } let rest = &attr_tokens[2..]; @@ -180,7 +169,10 @@ fn collect_sig_annotations( )) } }; - scan_type(ret_tokens, vec![TypePositionStep::Return], &mut annotations)?; + annotations.extend(anchor_at( + parse_refined_type_annotations(ret_tokens)?, + TypePositionStep::Return, + )); Ok(annotations) } @@ -239,59 +231,73 @@ fn emit_error(err: syn::Error, func: &FnItemWithSignature) -> TokenStream { } // --------------------------------------------------------------------------- -// Type-expression scanner: walks a type expression in `name: ty` / `ty` form, -// recording every refinement node and the path that locates it. +// Type-expression parser: extracts all refined-type annotations from a type +// expression's tokens. // --------------------------------------------------------------------------- -/// Walks the type expression `tokens` and appends any [`RefinedType`] nodes it -/// contains to `out` as [`RefinedTypeAnnotation`]s. `steps` is the path to -/// `tokens` itself; descending into a generic type argument extends it with a -/// [`TypePositionStep::TypeArg`]. -fn scan_type( +/// Parses the type expression in `tokens` and returns every refined-type +/// annotation it contains, with each annotation's `position` given relative to +/// the start of `tokens` (no root step prepended). +fn parse_refined_type_annotations( tokens: &[TokenTree2], - steps: Vec, - out: &mut Vec, -) -> syn::Result<()> { +) -> syn::Result> { if tokens.is_empty() { - return Ok(()); + return Ok(Vec::new()); } - // A refinement type is exactly a brace-delimited group. + // A refinement type is exactly a brace-delimited group. The annotation sits + // at the current position; further annotations may be nested inside the + // binder type (which lives at the same position). if tokens.len() == 1 { if let TokenTree2::Group(g) = &tokens[0] { if g.delimiter() == proc_macro2::Delimiter::Brace { let refined_type = parse_refined_type(g.stream())?; - let nested: Vec = refined_type.binder_ty.clone().into_iter().collect(); - out.push(RefinedTypeAnnotation { - position: steps.clone(), + let binder_ty_tokens: Vec = + refined_type.binder_ty.clone().into_iter().collect(); + let mut out = vec![RefinedTypeAnnotation { + position: Vec::new(), refined_type, - }); - // Descend into the binder type for nested refinements. - scan_type(&nested, steps, out)?; - return Ok(()); + }]; + out.extend(parse_refined_type_annotations(&binder_ty_tokens)?); + return Ok(out); } } } - // A nominal type `Name` (`Box` included). + // A nominal type `Name` (`Box` included): each generic + // argument sits one `TypeArg(i)` step deeper. if let TokenTree2::Ident(_) = &tokens[0] { if let Some(TokenTree2::Punct(p)) = tokens.get(1) { if p.as_char() == '<' { + let mut out = Vec::new(); let mut type_idx = 0; for arg in split_angle_args(&tokens[2..]) { if is_lifetime(&arg) { continue; } - let mut child = steps.clone(); - child.push(TypePositionStep::TypeArg(type_idx)); - scan_type(&arg, child, out)?; + out.extend(anchor_at( + parse_refined_type_annotations(&arg)?, + TypePositionStep::TypeArg(type_idx), + )); type_idx += 1; } + return Ok(out); } } } - Ok(()) + Ok(Vec::new()) +} + +/// Returns `annotations` with `root` prepended to every annotation's position. +fn anchor_at( + mut annotations: Vec, + root: TypePositionStep, +) -> Vec { + for annotation in &mut annotations { + annotation.position.insert(0, root); + } + annotations } // --------------------------------------------------------------------------- From d115126bd6552c91a5940cac03ccbe4c7aaffb26 Mon Sep 17 00:00:00 2001 From: Claude Date: Tue, 2 Jun 2026 16:15:22 +0000 Subject: [PATCH 16/18] Inline emit_error helper, trim verbose comments Drop emit_error and inline its two-line body at each call site, and shorten the module-level doc, the type and function doc comments, and the inline comments in parse_refined_type_annotations. Also remove the section-banner separator comments. https://claude.ai/code/session_01Km9xwaVaGQjnAy1iHsPewa --- thrust-macros/src/rty.rs | 138 ++++++++++++++------------------------- 1 file changed, 50 insertions(+), 88 deletions(-) diff --git a/thrust-macros/src/rty.rs b/thrust-macros/src/rty.rs index 698ae7c9..22537dcb 100644 --- a/thrust-macros/src/rty.rs +++ b/thrust-macros/src/rty.rs @@ -1,13 +1,10 @@ //! Refinement-type annotations: `param`, `ret`, `sig`. //! -//! These lower refinement types (e.g. `List<{ v: i32 | v > 0 }>`) into -//! `#[thrust::formula_fn]`s plus positioned `#[thrust::refinement_path(..)]` -//! path statements injected into the function body. The "type position" -//! addresses into the function type: a parameter (`$i`) or the return (the -//! `result` keyword) selects a function slot, and bare integer steps (`i`) -//! descend into generic arguments (enum args / `Box` pointee). For example, -//! `#[thrust::refinement_path(result, 0)]` is the first type-argument of the -//! return. +//! Lowers refinement types (e.g. `List<{ v: i32 | v > 0 }>`) into +//! `#[thrust::formula_fn]`s plus `#[thrust::refinement_path(..)]` path +//! statements injected into the function body. The path identifies the +//! sub-type the formula applies to: `$i` selects the i-th parameter, `result` +//! the return slot, and a bare integer `i` the i-th generic argument. use proc_macro::TokenStream; use proc_macro2::{TokenStream as TokenStream2, TokenTree as TokenTree2}; @@ -20,10 +17,8 @@ use super::{ }; /// One step in a refinement's type-position path; mirrors the plugin's -/// `rty::TypePositionStep`. -/// -/// The attribute encoding emitted into `#[thrust::refinement_path(..)]` is: -/// `Param(i)` → `$i`, `Return` → `result`, `TypeArg(i)` → a bare integer `i`. +/// `rty::TypePositionStep`. Encoded in `#[thrust::refinement_path(..)]` as +/// `Param(i)` → `$i`, `Return` → `result`, `TypeArg(i)` → `i`. #[derive(Clone, Copy)] enum TypePositionStep { Param(usize), @@ -31,42 +26,39 @@ enum TypePositionStep { TypeArg(usize), } -/// A refinement type `{ binder : ty | formula }` parsed from a type -/// expression in a `param` / `ret` / `sig` annotation. +/// A `{ binder : ty | formula }` parsed from a type expression. #[derive(Clone)] struct RefinedType { - /// The bound name (the `v` in `{ v: T | phi }`). binder: syn::Ident, - /// Tokens of the binder's type (`T` in `{ v: T | phi }`); kept as tokens - /// because nested refinement braces make it not always a valid Rust type. + /// Kept as tokens because nested refinements make it not always a valid + /// Rust type. binder_ty: TokenStream2, - /// Tokens of the refinement formula expression (`phi` in `{ v: T | phi }`). formula: TokenStream2, } -/// A [`RefinedType`] together with the position within the function signature -/// at which it applies. +/// A [`RefinedType`] together with the position where it applies. #[derive(Clone)] struct RefinedTypeAnnotation { - /// Path from the function root to the sub-type where `refined_type` applies. position: Vec, refined_type: RefinedType, } -// --------------------------------------------------------------------------- -// Macro entry points (called by lib.rs). -// --------------------------------------------------------------------------- - pub(crate) fn expand_param(attr: TokenStream, item: TokenStream) -> TokenStream { let func = parse_macro_input!(item as FnItemWithSignature); let outer_context = match extract_outer_context(&func) { Ok(ctx) => ctx, - Err(e) => return emit_error(e, &func), + Err(e) => { + let err = e.to_compile_error(); + return quote! { #err #func }.into(); + } }; let attr_tokens: Vec = TokenStream2::from(attr).into_iter().collect(); let annotations = match collect_param_annotations(&func, &attr_tokens) { Ok(a) => a, - Err(e) => return emit_error(e, &func), + Err(e) => { + let err = e.to_compile_error(); + return quote! { #err #func }.into(); + } }; expand_with_annotations(func, outer_context, annotations) } @@ -75,12 +67,18 @@ pub(crate) fn expand_ret(attr: TokenStream, item: TokenStream) -> TokenStream { let func = parse_macro_input!(item as FnItemWithSignature); let outer_context = match extract_outer_context(&func) { Ok(ctx) => ctx, - Err(e) => return emit_error(e, &func), + Err(e) => { + let err = e.to_compile_error(); + return quote! { #err #func }.into(); + } }; let attr_tokens: Vec = TokenStream2::from(attr).into_iter().collect(); let annotations = match collect_ret_annotations(&attr_tokens) { Ok(a) => a, - Err(e) => return emit_error(e, &func), + Err(e) => { + let err = e.to_compile_error(); + return quote! { #err #func }.into(); + } }; expand_with_annotations(func, outer_context, annotations) } @@ -89,25 +87,22 @@ pub(crate) fn expand_sig(attr: TokenStream, item: TokenStream) -> TokenStream { let func = parse_macro_input!(item as FnItemWithSignature); let outer_context = match extract_outer_context(&func) { Ok(ctx) => ctx, - Err(e) => return emit_error(e, &func), + Err(e) => { + let err = e.to_compile_error(); + return quote! { #err #func }.into(); + } }; let attr_tokens: Vec = TokenStream2::from(attr).into_iter().collect(); let annotations = match collect_sig_annotations(&func, &attr_tokens) { Ok(a) => a, - Err(e) => return emit_error(e, &func), + Err(e) => { + let err = e.to_compile_error(); + return quote! { #err #func }.into(); + } }; expand_with_annotations(func, outer_context, annotations) } -// --------------------------------------------------------------------------- -// Per-annotation collectors. -// -// Each parses one specific attribute form and extracts the -// [`RefinedTypeAnnotation`]s contained in its type expression(s). -// --------------------------------------------------------------------------- - -/// Parses a `#[param(name: ty)]` attribute and extracts the refined-type -/// annotations from `ty`, anchored at the position of the parameter `name`. fn collect_param_annotations( func: &FnItemWithSignature, attr_tokens: &[TokenTree2], @@ -118,16 +113,11 @@ fn collect_param_annotations( Ok(anchor_at(annotations, TypePositionStep::Param(idx))) } -/// Parses a `#[ret(ty)]` attribute and extracts the refined-type annotations -/// from `ty`, anchored at the return position. fn collect_ret_annotations(attr_tokens: &[TokenTree2]) -> syn::Result> { let annotations = parse_refined_type_annotations(attr_tokens)?; Ok(anchor_at(annotations, TypePositionStep::Return)) } -/// Parses a `#[sig(fn(n0: t0, ..) -> r)]` attribute and extracts the -/// refined-type annotations from every parameter type and the return type, -/// each anchored at the corresponding position. fn collect_sig_annotations( func: &FnItemWithSignature, attr_tokens: &[TokenTree2], @@ -176,11 +166,6 @@ fn collect_sig_annotations( Ok(annotations) } -// --------------------------------------------------------------------------- -// Macro-output assembly: from the collected annotations, produce the -// formula_fn declarations and the path-statement-injected function. -// --------------------------------------------------------------------------- - fn expand_with_annotations( mut func: FnItemWithSignature, outer_context: Option, @@ -225,19 +210,8 @@ fn expand_with_annotations( .into() } -fn emit_error(err: syn::Error, func: &FnItemWithSignature) -> TokenStream { - let err = err.to_compile_error(); - quote! { #err #func }.into() -} - -// --------------------------------------------------------------------------- -// Type-expression parser: extracts all refined-type annotations from a type -// expression's tokens. -// --------------------------------------------------------------------------- - -/// Parses the type expression in `tokens` and returns every refined-type -/// annotation it contains, with each annotation's `position` given relative to -/// the start of `tokens` (no root step prepended). +/// Returns every refined-type annotation contained in `tokens`, with each +/// annotation's position relative to the start of `tokens`. fn parse_refined_type_annotations( tokens: &[TokenTree2], ) -> syn::Result> { @@ -245,9 +219,8 @@ fn parse_refined_type_annotations( return Ok(Vec::new()); } - // A refinement type is exactly a brace-delimited group. The annotation sits - // at the current position; further annotations may be nested inside the - // binder type (which lives at the same position). + // A refinement type is a brace-delimited group; nested annotations in its + // binder type sit at the same position. if tokens.len() == 1 { if let TokenTree2::Group(g) = &tokens[0] { if g.delimiter() == proc_macro2::Delimiter::Brace { @@ -264,8 +237,7 @@ fn parse_refined_type_annotations( } } - // A nominal type `Name` (`Box` included): each generic - // argument sits one `TypeArg(i)` step deeper. + // A nominal type `Name`: each arg sits one `TypeArg(i)` deeper. if let TokenTree2::Ident(_) = &tokens[0] { if let Some(TokenTree2::Punct(p)) = tokens.get(1) { if p.as_char() == '<' { @@ -289,7 +261,7 @@ fn parse_refined_type_annotations( Ok(Vec::new()) } -/// Returns `annotations` with `root` prepended to every annotation's position. +/// Prepends `root` to every annotation's position. fn anchor_at( mut annotations: Vec, root: TypePositionStep, @@ -300,12 +272,7 @@ fn anchor_at( annotations } -// --------------------------------------------------------------------------- -// Token-level parsers. -// --------------------------------------------------------------------------- - -/// Parses `name : ` (a parameter binding or a refinement -/// binder). Returns the bound name and the remaining type tokens. +/// Parses `name : ` into the bound name and the type tokens. fn parse_name_typed_binding(tokens: &[TokenTree2]) -> syn::Result<(syn::Ident, Vec)> { let name = match tokens.first() { Some(TokenTree2::Ident(id)) => id.clone(), @@ -318,8 +285,7 @@ fn parse_name_typed_binding(tokens: &[TokenTree2]) -> syn::Result<(syn::Ident, V Ok((name, tokens[2..].to_vec())) } -/// Parses the contents of a `{ binder : ty | formula }` brace group into a -/// [`RefinedType`]. +/// Parses the contents of a `{ binder : ty | formula }` brace group. fn parse_refined_type(stream: TokenStream2) -> syn::Result { let toks: Vec = stream.into_iter().collect(); let bar = toks @@ -335,8 +301,8 @@ fn parse_refined_type(stream: TokenStream2) -> syn::Result { }) } -/// Splits the tokens following an opening `<` at top level by commas, stopping -/// at the matching `>`. +/// Splits tokens after an opening `<` by top-level commas, stopping at the +/// matching `>`. fn split_angle_args(tokens: &[TokenTree2]) -> Vec> { let mut args = Vec::new(); let mut cur = Vec::new(); @@ -413,10 +379,6 @@ fn param_index(func: &FnItemWithSignature, name: &syn::Ident) -> syn::Result TokenStream2 { tokens .into_iter() @@ -445,9 +407,9 @@ fn formula_fn_name(func: &FnItemWithSignature, ann: &RefinedTypeAnnotation) -> s format_ident!("_thrust_refine_{}_{}", func.sig().ident, pos) } -/// Builds the `#[thrust::formula_fn]` declaration that carries the annotation's -/// formula. Its first parameter is the refined type's value binder; the rest -/// are the enclosing function's parameters in model-typed form. +/// The `#[thrust::formula_fn]` declaration carrying the annotation's formula. +/// Its first parameter is the refined type's value binder; the rest are the +/// enclosing function's parameters in model-typed form. fn build_formula_fn( func: &FnItemWithSignature, outer_context: Option<&FnOuterItem>, @@ -475,8 +437,8 @@ fn build_formula_fn( } } -/// Builds the `#[thrust::refinement_path(..)] ;` path statement that links -/// the generated formula_fn to the annotation's position. +/// The `#[thrust::refinement_path(..)] ;` statement linking the generated +/// formula_fn to the annotation's position. fn build_refinement_path_stmt( func: &FnItemWithSignature, ann: &RefinedTypeAnnotation, From d65014e553601b873abb8447dff6d48093246763 Mon Sep 17 00:00:00 2001 From: Claude Date: Tue, 2 Jun 2026 16:38:28 +0000 Subject: [PATCH 17/18] Handle paths/refs in refinement-type parser; strip nested refinements from binder type Address two reviewer findings in parse_refined_type_annotations: - It previously only recognized refined generics when the token stream started with `Ident <` directly, silently missing common spellings like `std::vec::Vec<{..}>`, `crate::Pair<{..}>`, and `&Vec<{..}>`. Walk through path segments to the final ident, and treat a leading `&` (with optional lifetime and `mut`) as a TypeArg(0) step into the pointee. - build_formula_fn used the binder type's tokens verbatim in a Rust type position (`<#binder_ty as Model>::Ty`), which is ill-formed when the binder type itself contains nested refinement braces such as `Pair<{ v: i32 | v > 0 }>`. The parser now stores binder_ty with every nested refinement brace replaced by the brace's binder type, yielding a valid Rust type at emission while still recording the nested annotations (at their own positions) separately. Inline the now-redundant parse_refined_type helper into the brace arm of parse_refined_type_annotations. Add pass/fail tests for each fix: - refine_param_path_qualified: `crate::Pair<{ v: i32 | v > 0 }>`. - refine_param_nested_binder: `{ q: Pair<{ v: i32 | v > 0 }> | true }`. https://claude.ai/code/session_01Km9xwaVaGQjnAy1iHsPewa --- tests/ui/fail/refine_param_nested_binder.rs | 19 +++ tests/ui/fail/refine_param_path_qualified.rs | 19 +++ tests/ui/pass/refine_param_nested_binder.rs | 19 +++ tests/ui/pass/refine_param_path_qualified.rs | 19 +++ thrust-macros/src/rty.rs | 124 ++++++++++++++----- 5 files changed, 172 insertions(+), 28 deletions(-) create mode 100644 tests/ui/fail/refine_param_nested_binder.rs create mode 100644 tests/ui/fail/refine_param_path_qualified.rs create mode 100644 tests/ui/pass/refine_param_nested_binder.rs create mode 100644 tests/ui/pass/refine_param_path_qualified.rs diff --git a/tests/ui/fail/refine_param_nested_binder.rs b/tests/ui/fail/refine_param_nested_binder.rs new file mode 100644 index 00000000..c75acfe9 --- /dev/null +++ b/tests/ui/fail/refine_param_nested_binder.rs @@ -0,0 +1,19 @@ +//@error-in-other-file: Unsat + +pub enum Pair { + Mk(T, T), +} + +impl thrust_models::Model for Pair { + type Ty = Self; +} + +#[thrust_macros::param(p: { q: Pair<{ v: i32 | v > 0 }> | true })] +#[thrust_macros::ret({ r: i32 | r > 100 })] +fn first(p: Pair) -> i32 { + match p { + Pair::Mk(a, _) => a, + } +} + +fn main() {} diff --git a/tests/ui/fail/refine_param_path_qualified.rs b/tests/ui/fail/refine_param_path_qualified.rs new file mode 100644 index 00000000..3bbf8170 --- /dev/null +++ b/tests/ui/fail/refine_param_path_qualified.rs @@ -0,0 +1,19 @@ +//@error-in-other-file: Unsat + +pub enum Pair { + Mk(T, T), +} + +impl thrust_models::Model for Pair { + type Ty = Self; +} + +#[thrust_macros::param(p: crate::Pair<{ v: i32 | v > 0 }>)] +#[thrust_macros::ret({ r: i32 | r > 100 })] +fn first(p: Pair) -> i32 { + match p { + Pair::Mk(a, _) => a, + } +} + +fn main() {} diff --git a/tests/ui/pass/refine_param_nested_binder.rs b/tests/ui/pass/refine_param_nested_binder.rs new file mode 100644 index 00000000..9c49e3d7 --- /dev/null +++ b/tests/ui/pass/refine_param_nested_binder.rs @@ -0,0 +1,19 @@ +//@check-pass + +pub enum Pair { + Mk(T, T), +} + +impl thrust_models::Model for Pair { + type Ty = Self; +} + +#[thrust_macros::param(p: { q: Pair<{ v: i32 | v > 0 }> | true })] +#[thrust_macros::ret({ r: i32 | r > 0 })] +fn first(p: Pair) -> i32 { + match p { + Pair::Mk(a, _) => a, + } +} + +fn main() {} diff --git a/tests/ui/pass/refine_param_path_qualified.rs b/tests/ui/pass/refine_param_path_qualified.rs new file mode 100644 index 00000000..ca8edd43 --- /dev/null +++ b/tests/ui/pass/refine_param_path_qualified.rs @@ -0,0 +1,19 @@ +//@check-pass + +pub enum Pair { + Mk(T, T), +} + +impl thrust_models::Model for Pair { + type Ty = Self; +} + +#[thrust_macros::param(p: crate::Pair<{ v: i32 | v > 0 }>)] +#[thrust_macros::ret({ r: i32 | r > 0 })] +fn first(p: Pair) -> i32 { + match p { + Pair::Mk(a, _) => a, + } +} + +fn main() {} diff --git a/thrust-macros/src/rty.rs b/thrust-macros/src/rty.rs index 22537dcb..f770e30d 100644 --- a/thrust-macros/src/rty.rs +++ b/thrust-macros/src/rty.rs @@ -220,30 +220,77 @@ fn parse_refined_type_annotations( } // A refinement type is a brace-delimited group; nested annotations in its - // binder type sit at the same position. + // binder type sit at the same position. The stored `binder_ty` has nested + // refinement braces stripped so it is a valid Rust type when emitted. if tokens.len() == 1 { if let TokenTree2::Group(g) = &tokens[0] { if g.delimiter() == proc_macro2::Delimiter::Brace { - let refined_type = parse_refined_type(g.stream())?; - let binder_ty_tokens: Vec = - refined_type.binder_ty.clone().into_iter().collect(); + let inner: Vec = g.stream().into_iter().collect(); + let bar = inner + .iter() + .position(|tt| matches!(tt, TokenTree2::Punct(p) if p.as_char() == '|')) + .ok_or_else(|| err_tokens(&inner, "refinement type must contain `|`"))?; + let (binder, binder_ty_tokens) = parse_name_typed_binding(&inner[..bar])?; + let formula: TokenStream2 = inner[bar + 1..].iter().cloned().collect(); + let nested = parse_refined_type_annotations(&binder_ty_tokens)?; + let stripped_binder_ty: TokenStream2 = strip_refinement_braces(&binder_ty_tokens) + .into_iter() + .collect(); let mut out = vec![RefinedTypeAnnotation { position: Vec::new(), - refined_type, + refined_type: RefinedType { + binder, + binder_ty: stripped_binder_ty, + formula, + }, }]; - out.extend(parse_refined_type_annotations(&binder_ty_tokens)?); + out.extend(nested); return Ok(out); } } } - // A nominal type `Name`: each arg sits one `TypeArg(i)` deeper. - if let TokenTree2::Ident(_) = &tokens[0] { - if let Some(TokenTree2::Punct(p)) = tokens.get(1) { - if p.as_char() == '<' { + // A reference `& [lifetime] [mut] T` is encoded by the plugin as a pointer; + // descend into the pointee at `TypeArg(0)`. + if let Some(TokenTree2::Punct(p)) = tokens.first() { + if p.as_char() == '&' { + let mut cursor = 1; + if let Some(TokenTree2::Punct(p)) = tokens.get(cursor) { + if p.as_char() == '\'' { + cursor += 1; + if matches!(tokens.get(cursor), Some(TokenTree2::Ident(_))) { + cursor += 1; + } + } + } + if matches!(tokens.get(cursor), Some(TokenTree2::Ident(id)) if id == "mut") { + cursor += 1; + } + return Ok(anchor_at( + parse_refined_type_annotations(&tokens[cursor..])?, + TypePositionStep::TypeArg(0), + )); + } + } + + // A path-qualified nominal type `path::to::Name`: walk past the path + // segments and recurse into each generic argument at `TypeArg(i)`. + let mut cursor = 0; + loop { + if !matches!(tokens.get(cursor), Some(TokenTree2::Ident(_))) { + return Ok(Vec::new()); + } + let next = cursor + 1; + match (tokens.get(next), tokens.get(next + 1)) { + (Some(TokenTree2::Punct(p1)), Some(TokenTree2::Punct(p2))) + if p1.as_char() == ':' && p2.as_char() == ':' => + { + cursor = next + 2; + } + (Some(TokenTree2::Punct(p)), _) if p.as_char() == '<' => { let mut out = Vec::new(); let mut type_idx = 0; - for arg in split_angle_args(&tokens[2..]) { + for arg in split_angle_args(&tokens[next + 1..]) { if is_lifetime(&arg) { continue; } @@ -255,10 +302,47 @@ fn parse_refined_type_annotations( } return Ok(out); } + _ => return Ok(Vec::new()), } } +} - Ok(Vec::new()) +/// Returns `tokens` with every refinement-type brace `{ b : ty | phi }` +/// replaced by its (recursively stripped) binder type `ty`, so the result is +/// a valid Rust type expression. +fn strip_refinement_braces(tokens: &[TokenTree2]) -> Vec { + let mut out = Vec::new(); + for tt in tokens { + match tt { + TokenTree2::Group(g) if g.delimiter() == proc_macro2::Delimiter::Brace => { + let inner: Vec = g.stream().into_iter().collect(); + let as_refinement = inner + .iter() + .position(|tt| matches!(tt, TokenTree2::Punct(p) if p.as_char() == '|')) + .and_then(|bar| parse_name_typed_binding(&inner[..bar]).ok()); + if let Some((_, ty_tokens)) = as_refinement { + out.extend(strip_refinement_braces(&ty_tokens)); + } else { + let stripped: TokenStream2 = + strip_refinement_braces(&inner).into_iter().collect(); + out.push(TokenTree2::Group(proc_macro2::Group::new( + g.delimiter(), + stripped, + ))); + } + } + TokenTree2::Group(g) => { + let inner: Vec = g.stream().into_iter().collect(); + let stripped: TokenStream2 = strip_refinement_braces(&inner).into_iter().collect(); + out.push(TokenTree2::Group(proc_macro2::Group::new( + g.delimiter(), + stripped, + ))); + } + other => out.push(other.clone()), + } + } + out } /// Prepends `root` to every annotation's position. @@ -285,22 +369,6 @@ fn parse_name_typed_binding(tokens: &[TokenTree2]) -> syn::Result<(syn::Ident, V Ok((name, tokens[2..].to_vec())) } -/// Parses the contents of a `{ binder : ty | formula }` brace group. -fn parse_refined_type(stream: TokenStream2) -> syn::Result { - let toks: Vec = stream.into_iter().collect(); - let bar = toks - .iter() - .position(|tt| matches!(tt, TokenTree2::Punct(p) if p.as_char() == '|')) - .ok_or_else(|| err_tokens(&toks, "refinement type must contain `|`"))?; - let (binder, binder_ty_tokens) = parse_name_typed_binding(&toks[..bar])?; - let formula: TokenStream2 = toks[bar + 1..].iter().cloned().collect(); - Ok(RefinedType { - binder, - binder_ty: binder_ty_tokens.iter().cloned().collect(), - formula, - }) -} - /// Splits tokens after an opening `<` by top-level commas, stopping at the /// matching `>`. fn split_angle_args(tokens: &[TokenTree2]) -> Vec> { From 1f352659b1f2da7e2dd8073a06f83ec51dfc1cd9 Mon Sep 17 00:00:00 2001 From: Claude Date: Wed, 3 Jun 2026 14:22:48 +0000 Subject: [PATCH 18/18] Simplify generic-prefix scan and fold strip into the single parser pass Address two follow-ups on the review-fix commit: - The path-walking loop was over-careful: at this layer the only thing that matters is whether tokens contain a `<` whose args might host refinements; the prefix is opaque. Replace the loop with a single search for the first top-level `<`, which covers paths (`crate::Pair<..>`), unqualified names (`Vec<..>`), and `dyn` / `impl` prefixes without enumerating them. - The separate strip_refinement_braces was a second recursion over the same tokens. Fold it into parse_refined_type_annotations so each call returns the annotations it found together with the type expression rewritten as plain Rust (every refinement brace replaced by its binder type). Drop strip_refinement_braces. https://claude.ai/code/session_01Km9xwaVaGQjnAy1iHsPewa --- thrust-macros/src/rty.rs | 154 +++++++++++++++------------------------ 1 file changed, 58 insertions(+), 96 deletions(-) diff --git a/thrust-macros/src/rty.rs b/thrust-macros/src/rty.rs index f770e30d..2715c0d1 100644 --- a/thrust-macros/src/rty.rs +++ b/thrust-macros/src/rty.rs @@ -109,12 +109,12 @@ fn collect_param_annotations( ) -> syn::Result> { let (name, ty_tokens) = parse_name_typed_binding(attr_tokens)?; let idx = param_index(func, &name)?; - let annotations = parse_refined_type_annotations(&ty_tokens)?; + let (annotations, _) = parse_refined_type_annotations(&ty_tokens)?; Ok(anchor_at(annotations, TypePositionStep::Param(idx))) } fn collect_ret_annotations(attr_tokens: &[TokenTree2]) -> syn::Result> { - let annotations = parse_refined_type_annotations(attr_tokens)?; + let (annotations, _) = parse_refined_type_annotations(attr_tokens)?; Ok(anchor_at(annotations, TypePositionStep::Return)) } @@ -139,10 +139,8 @@ fn collect_sig_annotations( } let (name, ty_tokens) = parse_name_typed_binding(&arg)?; let idx = param_index(func, &name)?; - annotations.extend(anchor_at( - parse_refined_type_annotations(&ty_tokens)?, - TypePositionStep::Param(idx), - )); + let (param_annotations, _) = parse_refined_type_annotations(&ty_tokens)?; + annotations.extend(anchor_at(param_annotations, TypePositionStep::Param(idx))); } let rest = &attr_tokens[2..]; @@ -159,10 +157,8 @@ fn collect_sig_annotations( )) } }; - annotations.extend(anchor_at( - parse_refined_type_annotations(ret_tokens)?, - TypePositionStep::Return, - )); + let (ret_annotations, _) = parse_refined_type_annotations(ret_tokens)?; + annotations.extend(anchor_at(ret_annotations, TypePositionStep::Return)); Ok(annotations) } @@ -210,18 +206,20 @@ fn expand_with_annotations( .into() } -/// Returns every refined-type annotation contained in `tokens`, with each -/// annotation's position relative to the start of `tokens`. +/// Parses the type expression in `tokens` and returns both the refined-type +/// annotations it contains (positioned relative to the start of `tokens`) and +/// the type expression as plain Rust tokens (with every refinement brace +/// replaced by the brace's binder type). fn parse_refined_type_annotations( tokens: &[TokenTree2], -) -> syn::Result> { +) -> syn::Result<(Vec, Vec)> { if tokens.is_empty() { - return Ok(Vec::new()); + return Ok((Vec::new(), Vec::new())); } - // A refinement type is a brace-delimited group; nested annotations in its - // binder type sit at the same position. The stored `binder_ty` has nested - // refinement braces stripped so it is a valid Rust type when emitted. + // A refinement type `{ binder : ty | formula }`: emit one annotation here, + // then descend into the binder type for any nested annotations. The + // stripped form of the whole brace is the stripped form of the binder type. if tokens.len() == 1 { if let TokenTree2::Group(g) = &tokens[0] { if g.delimiter() == proc_macro2::Delimiter::Brace { @@ -232,26 +230,24 @@ fn parse_refined_type_annotations( .ok_or_else(|| err_tokens(&inner, "refinement type must contain `|`"))?; let (binder, binder_ty_tokens) = parse_name_typed_binding(&inner[..bar])?; let formula: TokenStream2 = inner[bar + 1..].iter().cloned().collect(); - let nested = parse_refined_type_annotations(&binder_ty_tokens)?; - let stripped_binder_ty: TokenStream2 = strip_refinement_braces(&binder_ty_tokens) - .into_iter() - .collect(); + let (nested, stripped_binder_ty) = + parse_refined_type_annotations(&binder_ty_tokens)?; let mut out = vec![RefinedTypeAnnotation { position: Vec::new(), refined_type: RefinedType { binder, - binder_ty: stripped_binder_ty, + binder_ty: stripped_binder_ty.iter().cloned().collect(), formula, }, }]; out.extend(nested); - return Ok(out); + return Ok((out, stripped_binder_ty)); } } } - // A reference `& [lifetime] [mut] T` is encoded by the plugin as a pointer; - // descend into the pointee at `TypeArg(0)`. + // A reference `& [lifetime] [mut] T`: encoded by the plugin as a pointer, + // so descend into the pointee at `TypeArg(0)`. if let Some(TokenTree2::Punct(p)) = tokens.first() { if p.as_char() == '&' { let mut cursor = 1; @@ -266,83 +262,49 @@ fn parse_refined_type_annotations( if matches!(tokens.get(cursor), Some(TokenTree2::Ident(id)) if id == "mut") { cursor += 1; } - return Ok(anchor_at( - parse_refined_type_annotations(&tokens[cursor..])?, - TypePositionStep::TypeArg(0), - )); + let (nested, stripped_pointee) = parse_refined_type_annotations(&tokens[cursor..])?; + let mut stripped = tokens[..cursor].to_vec(); + stripped.extend(stripped_pointee); + return Ok((anchor_at(nested, TypePositionStep::TypeArg(0)), stripped)); } } - // A path-qualified nominal type `path::to::Name`: walk past the path - // segments and recurse into each generic argument at `TypeArg(i)`. - let mut cursor = 0; - loop { - if !matches!(tokens.get(cursor), Some(TokenTree2::Ident(_))) { - return Ok(Vec::new()); - } - let next = cursor + 1; - match (tokens.get(next), tokens.get(next + 1)) { - (Some(TokenTree2::Punct(p1)), Some(TokenTree2::Punct(p2))) - if p1.as_char() == ':' && p2.as_char() == ':' => - { - cursor = next + 2; - } - (Some(TokenTree2::Punct(p)), _) if p.as_char() == '<' => { - let mut out = Vec::new(); - let mut type_idx = 0; - for arg in split_angle_args(&tokens[next + 1..]) { - if is_lifetime(&arg) { - continue; - } - out.extend(anchor_at( - parse_refined_type_annotations(&arg)?, - TypePositionStep::TypeArg(type_idx), - )); - type_idx += 1; - } - return Ok(out); - } - _ => return Ok(Vec::new()), + // Anything else: scan past the qualifier tokens (path segments, `dyn`, + // `impl`, …) to the first top-level `<`, then recurse into each generic + // argument at `TypeArg(i)`. With no `<` there's nothing to descend into. + let Some(lt) = tokens + .iter() + .position(|tt| matches!(tt, TokenTree2::Punct(p) if p.as_char() == '<')) + else { + return Ok((Vec::new(), tokens.to_vec())); + }; + let mut out_annotations = Vec::new(); + let mut stripped = tokens[..=lt].to_vec(); + let mut type_idx = 0; + for (i, arg) in split_angle_args(&tokens[lt + 1..]).into_iter().enumerate() { + if i > 0 { + stripped.push(TokenTree2::Punct(proc_macro2::Punct::new( + ',', + proc_macro2::Spacing::Alone, + ))); } - } -} - -/// Returns `tokens` with every refinement-type brace `{ b : ty | phi }` -/// replaced by its (recursively stripped) binder type `ty`, so the result is -/// a valid Rust type expression. -fn strip_refinement_braces(tokens: &[TokenTree2]) -> Vec { - let mut out = Vec::new(); - for tt in tokens { - match tt { - TokenTree2::Group(g) if g.delimiter() == proc_macro2::Delimiter::Brace => { - let inner: Vec = g.stream().into_iter().collect(); - let as_refinement = inner - .iter() - .position(|tt| matches!(tt, TokenTree2::Punct(p) if p.as_char() == '|')) - .and_then(|bar| parse_name_typed_binding(&inner[..bar]).ok()); - if let Some((_, ty_tokens)) = as_refinement { - out.extend(strip_refinement_braces(&ty_tokens)); - } else { - let stripped: TokenStream2 = - strip_refinement_braces(&inner).into_iter().collect(); - out.push(TokenTree2::Group(proc_macro2::Group::new( - g.delimiter(), - stripped, - ))); - } - } - TokenTree2::Group(g) => { - let inner: Vec = g.stream().into_iter().collect(); - let stripped: TokenStream2 = strip_refinement_braces(&inner).into_iter().collect(); - out.push(TokenTree2::Group(proc_macro2::Group::new( - g.delimiter(), - stripped, - ))); - } - other => out.push(other.clone()), + if is_lifetime(&arg) { + stripped.extend(arg); + continue; } + let (arg_annotations, stripped_arg) = parse_refined_type_annotations(&arg)?; + out_annotations.extend(anchor_at( + arg_annotations, + TypePositionStep::TypeArg(type_idx), + )); + stripped.extend(stripped_arg); + type_idx += 1; } - out + stripped.push(TokenTree2::Punct(proc_macro2::Punct::new( + '>', + proc_macro2::Spacing::Alone, + ))); + Ok((out_annotations, stripped)) } /// Prepends `root` to every annotation's position.