Skip to content

Commit 6071312

Browse files
marc-jasper-sonarsourcesonartech
authored andcommitted
SONARPY-3915 Improve SQLQueriesCheck with reaching definitions analysis (#960)
GitOrigin-RevId: 693473abdbaa10b0ee7648ee4e8f830159bd5480
1 parent debf131 commit 6071312

File tree

2 files changed

+35
-6
lines changed

2 files changed

+35
-6
lines changed

python-checks/src/main/java/org/sonar/python/checks/hotspots/SQLQueriesCheck.java

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -111,8 +111,8 @@ private void checkCallExpression(SubscriptionContext context) {
111111
}
112112

113113
private static void addIssue(SubscriptionContext context, CallExpression callExpression) {
114-
Optional<Tree> secondary = sensitiveArgumentValue(callExpression);
115-
secondary.ifPresent(tree -> context.addIssue(callExpression, MESSAGE).secondary(tree, null));
114+
Optional<Tree> secondary = sensitiveArgumentValue(callExpression, context);
115+
secondary.ifPresent(tree -> context.addIssue(callExpression, MESSAGE).secondary(tree, null));
116116
}
117117

118118
private static boolean isException(CallExpression callExpression, String functionName) {
@@ -123,7 +123,7 @@ private static boolean isException(CallExpression callExpression, String functio
123123
return argListNode.isEmpty();
124124
}
125125

126-
private static Optional<Tree> sensitiveArgumentValue(CallExpression callExpression) {
126+
private static Optional<Tree> sensitiveArgumentValue(CallExpression callExpression, SubscriptionContext ctx) {
127127
List<Argument> argListNode = callExpression.arguments();
128128
if (argListNode.isEmpty()) {
129129
return Optional.empty();
@@ -134,14 +134,27 @@ private static Optional<Tree> sensitiveArgumentValue(CallExpression callExpressi
134134
}
135135
Expression expression = getExpression(((RegularArgument) arg).expression());
136136
if (expression.is(Tree.Kind.NAME)) {
137-
expression = Expressions.singleAssignedValue((Name) expression);
137+
return findFormattedValue((Name) expression, ctx);
138138
}
139-
if (expression != null && isFormatted(expression)) {
139+
if (isFormatted(expression)) {
140140
return Optional.of(expression);
141141
}
142142
return Optional.empty();
143143
}
144144

145+
private static Optional<Tree> findFormattedValue(Name name, SubscriptionContext ctx) {
146+
Set<Expression> values = ctx.valuesAtLocation(name);
147+
if (!values.isEmpty()) {
148+
return values.stream()
149+
.filter(SQLQueriesCheck::isFormatted)
150+
.findFirst()
151+
.map(Tree.class::cast);
152+
}
153+
return Optional.ofNullable(Expressions.singleAssignedValue(name))
154+
.filter(SQLQueriesCheck::isFormatted)
155+
.map(Tree.class::cast);
156+
}
157+
145158
private static boolean isFormatted(Expression tree) {
146159
FormattedStringVisitor visitor = new FormattedStringVisitor();
147160
tree.accept(visitor);

python-checks/src/test/resources/checks/hotspots/sqlQuery.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def query_my_user(request, params):
3030
MyUser.objects.raw(formatted_request3) # Noncompliant
3131
MyUser.objects.raw(y := formatted_request3) # Noncompliant
3232
MyUser.objects.raw((y := formatted_request3)) # Noncompliant
33-
MyUser.objects.raw(formatted_request4) # FN, multiple assignments
33+
MyUser.objects.raw(formatted_request4) # Noncompliant
3434
MyUser.objects.raw(formatted_request5) # OK
3535
MyUser.objects.raw(*formatted_request5) # OK
3636

@@ -58,5 +58,21 @@ def query_my_user(request, params):
5858
MyUser.objects.extra({ 'mycol': "select col from sometable here mycol = %s and othercol = " + value}) # Noncompliant
5959
MyUser.objects.extra({ 'mycol': "select col from sometable here mycol = %s and othercol = " + ""}) # Noncompliant
6060

61+
def test_reaching_definitions(self, request, value):
62+
query = "SELECT 1"
63+
if request:
64+
query = 'SELECT * FROM mytable WHERE name = "%s"' % value
65+
MyUser.objects.raw(query) # Noncompliant
66+
67+
safe_query = "SELECT 1"
68+
if request:
69+
safe_query = "SELECT 2"
70+
MyUser.objects.raw(safe_query) # OK
71+
72+
all_formatted = 'SELECT * FROM mytable WHERE name = "%s"' % value
73+
if request:
74+
all_formatted = f'SELECT * FROM mytable WHERE name = "{value}"'
75+
MyUser.objects.raw(all_formatted) # Noncompliant
76+
6177
def fun():
6278
pass

0 commit comments

Comments
 (0)