Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions crates/sprout-db/src/channel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,34 @@ pub async fn get_members_bulk(pool: &PgPool, channel_ids: &[Uuid]) -> Result<Vec
rows.into_iter().map(row_to_member_record).collect()
}

/// Return the subset of `channel_ids` that refer to live, non-deleted channels.
///
/// This intentionally checks only channel existence/liveness, not membership. Callers use it
/// for relay-level public-readable policy after parsing the configured allowlist.
pub async fn filter_live_channel_ids(pool: &PgPool, channel_ids: &[Uuid]) -> Result<Vec<Uuid>> {
if channel_ids.is_empty() {
return Ok(Vec::new());
}

let rows = sqlx::query(
r#"
SELECT id
FROM channels
WHERE id = ANY($1) AND deleted_at IS NULL
"#,
)
.bind(channel_ids)
.fetch_all(pool)
.await?;

rows.into_iter()
.map(|r| {
let id: Uuid = r.try_get("id")?;
Ok(id)
})
.collect()
}

/// Get all channel IDs accessible to a pubkey.
///
/// Includes channels where the pubkey is an active member AND all open channels.
Expand Down Expand Up @@ -1203,6 +1231,46 @@ mod tests {
Keys::generate().public_key().to_bytes().to_vec()
}

#[tokio::test]
#[ignore = "requires Postgres"]
async fn test_filter_live_channel_ids_excludes_deleted_and_missing_channels() {
let pool = setup_pool().await;
let owner_pk = random_pubkey();
ensure_user(&pool, &owner_pk).await.expect("ensure owner");

let live = create_channel(
&pool,
&format!("test-public-live-{}", Uuid::new_v4()),
ChannelType::Stream,
ChannelVisibility::Private,
None,
&owner_pk,
None,
)
.await
.expect("create live channel");
let deleted = create_channel(
&pool,
&format!("test-public-deleted-{}", Uuid::new_v4()),
ChannelType::Stream,
ChannelVisibility::Private,
None,
&owner_pk,
None,
)
.await
.expect("create deleted channel");
soft_delete_channel(&pool, deleted.id)
.await
.expect("soft delete channel");

let result = filter_live_channel_ids(&pool, &[live.id, deleted.id, Uuid::new_v4()])
.await
.expect("filter live channel ids");

assert_eq!(result, vec![live.id]);
}

/// Agent owner (non-admin) can remove their own bot from a channel.
#[tokio::test]
#[ignore = "requires Postgres"]
Expand Down
62 changes: 44 additions & 18 deletions crates/sprout-db/src/event.rs
Original file line number Diff line number Diff line change
Expand Up @@ -218,25 +218,23 @@ pub async fn query_events(pool: &PgPool, q: &EventQuery) -> Result<Vec<StoredEve
qb.push(format!(" AND {col_prefix}channel_id IS NULL"));
}

// Multi-channel IN pushdown: restrict to events in any of these channels
// OR global events (channel_id IS NULL). Used by NIP-45 COUNT to enforce
// channel access at the SQL level without fetching all rows.
// Multi-channel IN pushdown: restrict to events in any of these channels.
// Used by NIP-45 COUNT global filters to enforce channel access at SQL level
// without broadening public-readable channel access to global events.
//
// SECURITY: Some(empty vec) means "user has access to NO channels" —
// only global events (channel_id IS NULL) should be returned.
// SECURITY: Some(empty vec) means "match no channel-scoped events" — return
// empty immediately. Callers that intentionally include global events should
// query them separately via `global_only`.
if let Some(ref ch_ids) = q.channel_ids {
if ch_ids.is_empty() {
// No channel access — only global (non-channel) events visible.
qb.push(format!(" AND {col_prefix}channel_id IS NULL"));
return Ok(vec![]);
} else {
qb.push(format!(
" AND ({col_prefix}channel_id IS NULL OR {col_prefix}channel_id IN ("
));
qb.push(format!(" AND {col_prefix}channel_id IN ("));
let mut sep = qb.separated(", ");
for ch in ch_ids {
sep.push_bind(*ch);
}
qb.push("))");
qb.push(")");
}
}

Expand Down Expand Up @@ -438,20 +436,19 @@ pub async fn count_events(pool: &PgPool, q: &EventQuery) -> Result<i64> {
qb.push(format!(" AND {col_prefix}channel_id IS NULL"));
}

// Multi-channel IN pushdown for COUNT: restrict to accessible channels + global.
// SECURITY: Some(empty vec) = no channel access → global events only.
// Multi-channel IN pushdown for COUNT: restrict to accessible channel-scoped
// events only. Global events require an explicit `global_only` query.
// SECURITY: Some(empty vec) = no channel access → count nothing.
if let Some(ref ch_ids) = q.channel_ids {
if ch_ids.is_empty() {
qb.push(format!(" AND {col_prefix}channel_id IS NULL"));
return Ok(0);
} else {
qb.push(format!(
" AND ({col_prefix}channel_id IS NULL OR {col_prefix}channel_id IN ("
));
qb.push(format!(" AND {col_prefix}channel_id IN ("));
let mut sep = qb.separated(", ");
for ch in ch_ids {
sep.push_bind(*ch);
}
qb.push("))");
qb.push(")");
}
}

Expand Down Expand Up @@ -1064,6 +1061,35 @@ mod tests {
assert_eq!(extract_d_tag(&above), None);
}

#[test]
fn channel_ids_pushdown_is_channel_scoped_not_global() {
let ch = uuid::Uuid::new_v4();
let q = EventQuery {
channel_ids: Some(vec![ch]),
..Default::default()
};

assert_eq!(q.channel_ids.as_deref(), Some(&[ch][..]));
assert!(
!q.global_only,
"channel_ids pushdown is channel-scoped only; global reads require explicit global_only"
);
}

#[test]
fn empty_channel_ids_pushdown_matches_no_channel_events() {
let q = EventQuery {
channel_ids: Some(vec![]),
..Default::default()
};

assert_eq!(q.channel_ids.as_deref(), Some(&[][..]));
assert!(
!q.global_only,
"an empty effective readable set must not silently broaden to global events"
);
}

#[test]
fn extract_d_tag_single_element_d_tag_ignored() {
// A d tag with only one element (no value) should not match — parts.len() < 2
Expand Down
16 changes: 15 additions & 1 deletion crates/sprout-db/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,11 @@ impl Db {
channel::get_accessible_channel_ids(&self.pool, pubkey).await
}

/// Return the subset of channel IDs that exist and have not been soft-deleted.
pub async fn filter_live_channel_ids(&self, channel_ids: &[Uuid]) -> Result<Vec<Uuid>> {
channel::filter_live_channel_ids(&self.pool, channel_ids).await
}

/// Lists channels, optionally filtered by visibility.
pub async fn list_channels(
&self,
Expand Down Expand Up @@ -695,8 +700,17 @@ impl Db {
depth_limit: Option<u32>,
limit: u32,
cursor: Option<&[u8]>,
channel_id: Option<Uuid>,
) -> Result<Vec<thread::ThreadReply>> {
thread::get_thread_replies(&self.pool, root_event_id, depth_limit, limit, cursor).await
thread::get_thread_replies(
&self.pool,
root_event_id,
depth_limit,
limit,
cursor,
channel_id,
)
.await
}

/// Fetch aggregated thread stats.
Expand Down
9 changes: 9 additions & 0 deletions crates/sprout-db/src/thread.rs
Original file line number Diff line number Diff line change
Expand Up @@ -316,13 +316,15 @@ pub async fn decrement_reply_count(
/// - `cursor` -- if `Some(ts_bytes)`, returns replies with `event_created_at`
/// strictly after the timestamp encoded in `ts_bytes`. The bytes must be an
/// 8-byte big-endian i64 Unix timestamp in seconds.
/// - `channel_id` -- if `Some`, returns only replies in that channel.
/// - `limit` -- maximum rows returned (caller should cap this).
pub async fn get_thread_replies(
pool: &PgPool,
root_event_id: &[u8],
depth_limit: Option<u32>,
limit: u32,
cursor: Option<&[u8]>,
channel_id: Option<Uuid>,
) -> Result<Vec<ThreadReply>> {
// Decode cursor bytes -> DateTime<Utc> for the keyset condition.
let cursor_ts: Option<DateTime<Utc>> = match cursor {
Expand Down Expand Up @@ -367,6 +369,10 @@ pub async fn get_thread_replies(
sql.push_str(&format!(" AND tm.event_created_at > ${param_idx}"));
param_idx += 1;
}
if channel_id.is_some() {
sql.push_str(&format!(" AND tm.channel_id = ${param_idx}"));
param_idx += 1;
}

sql.push_str(&format!(
" ORDER BY tm.event_created_at ASC LIMIT ${param_idx}"
Expand All @@ -380,6 +386,9 @@ pub async fn get_thread_replies(
if let Some(ts) = cursor_ts {
q = q.bind(ts);
}
if let Some(ch_id) = channel_id {
q = q.bind(ch_id);
}
q = q.bind(limit as i32);

let rows = q.fetch_all(pool).await?;
Expand Down
58 changes: 39 additions & 19 deletions crates/sprout-relay/src/api/bridge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -258,9 +258,9 @@ pub async fn query_events(
));
}

// Get channels this user can access — same enforcement as WS REQ handler.
// Effective authenticated read set = normal access ∪ live public-readable allowlist.
let accessible_channels = state
.get_accessible_channel_ids_cached(&pubkey_bytes)
.effective_readable_channel_ids(&pubkey_bytes, None)
.await
.map_err(|e| internal_error(&format!("channel access lookup: {e}")))?;

Expand Down Expand Up @@ -365,11 +365,17 @@ pub async fn query_events(
_ => continue,
};

if let Some(ch_id) = extract_channel_from_filter(filter) {
if !accessible_channels.contains(&ch_id) {
handled.insert(idx);
continue;
}
let Some(ch_id) = extract_channel_from_filter(filter) else {
handled.insert(idx);
continue;
};
// Bridge is authenticated-only today, so `accessible_channels` is already the
// effective read set. If unauthenticated/public bridge reads are added later,
// they must resolve the same live public allowlist before this access check;
// never fall back to unscoped thread expansion for `depth_limit`.
if !accessible_channels.contains(&ch_id) {
handled.insert(idx);
continue;
}

let limit = filter
Expand All @@ -378,7 +384,7 @@ pub async fn query_events(
.min(BRIDGE_THREAD_MAX_LIMIT as usize) as u32;
let thread_replies = state
.db
.get_thread_replies(&root_bytes, Some(depth), limit, None)
.get_thread_replies(&root_bytes, Some(depth), limit, None, Some(ch_id))
.await
.map_err(|e| internal_error(&format!("thread query error: {e}")))?;

Expand Down Expand Up @@ -415,9 +421,7 @@ pub async fn query_events(
}
}

let mut query =
crate::handlers::req::build_event_query_from_filter(filter, &pubkey_bytes, &state)
.await;
let mut query = crate::handlers::req::build_event_query_from_filter(filter);

if let Some(bid) = extract_before_id(raw) {
if query.until.is_none() {
Expand Down Expand Up @@ -495,9 +499,9 @@ pub async fn count_events(
));
}

// Get channels this user can access.
// Effective authenticated read set = normal access ∪ live public-readable allowlist.
let accessible_channels = state
.get_accessible_channel_ids_cached(&pubkey_bytes)
.effective_readable_channel_ids(&pubkey_bytes, None)
.await
.map_err(|e| internal_error(&format!("channel access lookup: {e}")))?;

Expand All @@ -509,9 +513,7 @@ pub async fn count_events(
continue; // Skip filters targeting inaccessible channels.
}
// Channel is accessible — count with pushability check.
let query =
crate::handlers::req::build_event_query_from_filter(filter, &pubkey_bytes, &state)
.await;
let query = crate::handlers::req::build_event_query_from_filter(filter);
if crate::handlers::req::filter_fully_pushable(filter) {
match state.db.count_events(&query).await {
Ok(n) => total += n as u64,
Expand Down Expand Up @@ -541,9 +543,7 @@ pub async fn count_events(
} else {
// No channel filter — use SQL-level channel_ids pushdown to count
// only events in accessible channels (+ global events).
let mut query =
crate::handlers::req::build_event_query_from_filter(filter, &pubkey_bytes, &state)
.await;
let mut query = crate::handlers::req::build_event_query_from_filter(filter);
query.channel_ids = Some(accessible_channels.to_vec());

if crate::handlers::req::filter_fully_pushable(filter) {
Expand Down Expand Up @@ -1182,6 +1182,26 @@ mod tests {
assert!(extract_feed_types(&raw).is_none());
}

#[test]
fn depth_limited_thread_filter_without_channel_is_not_channel_accessible() {
let filter = nostr::Filter::new().event(nostr::EventId::from_slice(&[1u8; 32]).unwrap());

assert!(
extract_channel_from_filter(&filter).is_none(),
"depth_limit bridge path must fail closed unless caller supplies #h"
);
}

#[test]
fn depth_limited_thread_filter_with_channel_extracts_channel() {
let ch = uuid::Uuid::new_v4();
let filter = nostr::Filter::new()
.event(nostr::EventId::from_slice(&[1u8; 32]).unwrap())
.custom_tag(SingleLetterTag::lowercase(Alphabet::H), ch.to_string());

assert_eq!(extract_channel_from_filter(&filter), Some(ch));
}

#[test]
fn event_accessible_no_channel() {
let keys = Keys::generate();
Expand Down
Loading