@@ -7,12 +7,26 @@ import type { SQL } from "pg-sql2";
77import type { DataForType } from "graphile-build" ;
88import isSafeInteger from "lodash/isSafeInteger" ;
99import assert from "assert" ;
10+ import { inspect } from "util" ;
1011
1112// eslint-disable-next-line flowtype/no-weak-types
1213type GraphQLContext = any ;
1314
1415const identity = _ => _ !== null && _ !== undefined ;
1516
17+ function assertSafeName ( name : mixed ) {
18+ if ( typeof name !== "string" ) {
19+ throw new Error (
20+ `Expected name to be a string; instead received '${ inspect ( name ) } '`
21+ ) ;
22+ }
23+ if ( ! / ^ [ @ a - z A - Z 0 - 9 _ ] { 1 , 63 } $ / . test ( name ) ) {
24+ throw new Error (
25+ `Name '${ name } ' is not safe - either it is too long, too short, or has invalid characters`
26+ ) ;
27+ }
28+ }
29+
1630export default ( queryBuilderOptions : QueryBuilderOptions = { } ) => (
1731 from : SQL ,
1832 fromAlias : ?SQL ,
@@ -35,13 +49,34 @@ export default (queryBuilderOptions: QueryBuilderOptions = {}) => (
3549) => {
3650 const {
3751 pgQuery,
38- pgAggregateQuery,
52+ pgAggregateQuery, // Shorthand for using pgNamedQueryContainer/pgNamedQuery combo
53+ pgNamedQueryContainer = [ ] ,
54+ pgNamedQuery = [ ] ,
3955 pgCursorPrefix : reallyRawCursorPrefix ,
4056 pgDontUseAsterisk,
4157 calculateHasNextPage,
4258 calculateHasPreviousPage,
4359 usesCursor : explicitlyUsesCursor ,
4460 } = resolveData ;
61+ // Push a query container for aggregates
62+ if ( ( pgAggregateQuery && pgAggregateQuery . length ) || pgNamedQuery . length ) {
63+ pgNamedQueryContainer . push ( {
64+ name : "aggregates" ,
65+ query : ( { queryBuilder, options, innerQueryBuilder } ) => sql . fragment `\
66+ (
67+ select ${ innerQueryBuilder . build ( { onlyJsonField : true } ) }
68+ from ${ queryBuilder . getTableExpression ( ) } as ${ queryBuilder . getTableAlias ( ) }
69+ where ${ queryBuilder . buildWhereClause ( false , false , options ) }
70+ )` ,
71+ } ) ;
72+ }
73+ // Convert pgAggregateQuery to pgNamedQueryContainer/pgNamedQuery combo
74+ if ( pgAggregateQuery && pgAggregateQuery . length ) {
75+ // And a query for each previous query
76+ pgAggregateQuery . forEach ( query => {
77+ pgNamedQuery . push ( { name : "aggregates" , query } ) ;
78+ } ) ;
79+ }
4580
4681 const preventAsterisk = pgDontUseAsterisk
4782 ? pgDontUseAsterisk . length > 0
@@ -450,30 +485,57 @@ OR\
450485 fields . push ( [ hasPreviousPage , "hasPreviousPage" ] ) ;
451486 }
452487 }
453- if ( pgAggregateQuery && pgAggregateQuery . length ) {
454- const aggregateQueryBuilder = new QueryBuilder (
455- queryBuilderOptions ,
456- context ,
457- rootValue
458- ) ;
459- aggregateQueryBuilder . from (
460- queryBuilder . getTableExpression ( ) ,
461- queryBuilder . getTableAlias ( )
462- ) ;
488+ if ( pgNamedQuery && pgNamedQuery . length ) {
489+ const groups = { } ;
490+ pgNamedQuery . forEach ( ( { name, query } ) => {
491+ assertSafeName ( name ) ;
492+ if ( ! groups [ name ] ) {
493+ groups [ name ] = [ ] ;
494+ }
495+ groups [ name ] . push ( query ) ;
496+ } ) ;
497+ Object . keys ( groups ) . forEach ( groupName => {
498+ const queryCallbacks = groups [ groupName ] ;
463499
464- for ( let i = 0 , l = pgAggregateQuery . length ; i < l ; i ++ ) {
465- pgAggregateQuery [ i ] ( aggregateQueryBuilder ) ;
466- }
467- const aggregateJsonBuildObject = aggregateQueryBuilder . build ( {
468- onlyJsonField : true ,
500+ // Get container
501+ const containers = pgNamedQueryContainer . filter (
502+ c => c . name === groupName
503+ ) ;
504+ if ( containers . length === 0 ) {
505+ throw new Error (
506+ `${ queryCallbacks . length } pgNamedQuery entries with name: '${ groupName } ' existed, but there was no matching pgNamedQueryContainer.`
507+ ) ;
508+ }
509+ if ( containers . length > 1 ) {
510+ throw new Error (
511+ `${ containers . length } pgNamedQueryContainer entries with name: '${ groupName } ' existed, but there should be exactly one.`
512+ ) ;
513+ }
514+ const container = containers [ 0 ] ;
515+
516+ const innerQueryBuilder = new QueryBuilder (
517+ queryBuilderOptions ,
518+ context ,
519+ rootValue
520+ ) ;
521+ innerQueryBuilder . from (
522+ queryBuilder . getTableExpression ( ) ,
523+ queryBuilder . getTableAlias ( )
524+ ) ;
525+
526+ for ( let i = 0 , l = queryCallbacks . length ; i < l ; i ++ ) {
527+ queryCallbacks [ i ] ( innerQueryBuilder ) ;
528+ }
529+
530+ // Generate the SQL statement (e.g. `select ${innerQueryBuilder.build({onlyJsonField: true})} from ${queryBuilder.getTableExpression()} as ...`)
531+ const aggregatesSql = container . query ( {
532+ queryBuilder,
533+ innerQueryBuilder,
534+ options,
535+ } ) ;
536+
537+ fields . push ( [ aggregatesSql , groupName ] ) ;
469538 } ) ;
470- const aggregatesSql = sql . fragment `\
471- (
472- select ${ aggregateJsonBuildObject }
473- from ${ queryBuilder . getTableExpression ( ) } as ${ queryBuilder . getTableAlias ( ) }
474- where ${ queryBuilder . buildWhereClause ( false , false , options ) }
475- )` ;
476- fields . push ( [ aggregatesSql , "aggregates" ] ) ;
477539 }
478540 if ( options . withPaginationAsFields ) {
479541 return sql . fragment `${ sqlWith } select ${ sql . join (
0 commit comments