Skip to content

Commit bea59a7

Browse files
aayush3011 authored and VenkataAnilKumar committed
Fix: Lazy initialization of _InferenceService to prevent startup crash with AAD credentials (#46243)
1 parent ba72801 commit bea59a7

6 files changed

Lines changed: 464 additions & 7 deletions

sdk/cosmos/azure-cosmos/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@
77
#### Breaking Changes
88

99
#### Bugs Fixed
10+
* Fixed bug where `CosmosClient` construction with AAD credentials would crash at startup if the semantic reranking inference endpoint environment variable was not set, even when semantic reranking was not being used. The inference service is now lazily initialized on first use. See [PR 46243](https://github.com/Azure/azure-sdk-for-python/pull/46243)
1011

1112
#### Other Changes
1213

sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py

Lines changed: 14 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,7 @@
2525
"""
2626
import logging
2727
import os
28+
import threading
2829
import urllib.parse
2930
import uuid
3031
from concurrent.futures.thread import ThreadPoolExecutor
@@ -262,8 +263,7 @@ def __init__( # pylint: disable=too-many-statements
262263
)
263264

264265
self._inference_service: Optional[_InferenceService] = None
265-
if self.aad_credentials:
266-
self._inference_service = _InferenceService(self)
266+
self._inference_service_lock = threading.Lock()
267267

268268
# Query compatibility mode.
269269
# Allows to specify compatibility mode used by client when making query requests. Should be removed when
@@ -332,7 +332,18 @@ def _set_client_consistency_level(
332332
self.session = None
333333

334334
def _get_inference_service(self) -> Optional[_InferenceService]:
    """Return the inference service, building it on first access when AAD credentials are set.

    Nothing is built while ``self.aad_credentials`` is falsy; the stored value (``None`` unless
    it was assigned elsewhere) is returned unchanged.

    :returns: The cached ``_InferenceService`` instance, or ``None`` when none has been created.
    :raises ~azure.cosmos.exceptions.CosmosHttpResponseError: With a BAD_REQUEST status code when
        constructing ``_InferenceService`` raises ``ValueError``.
    """
    # Fast path: already built, or there are no AAD credentials so nothing will be built.
    if self._inference_service is not None or not self.aad_credentials:
        return self._inference_service

    with self._inference_service_lock:
        # Re-check while holding the lock; another caller may have finished building it.
        if self._inference_service is not None:
            return self._inference_service
        try:
            self._inference_service = _InferenceService(self)
        except ValueError as e:
            raise exceptions.CosmosHttpResponseError(
                message=f"Failed to initialize inference service: {e}",
                response=None,
                status_code=http_constants.StatusCodes.BAD_REQUEST
            ) from e
    return self._inference_service
337348

338349
@property

sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py

Lines changed: 10 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -259,8 +259,6 @@ def __init__( # pylint: disable=too-many-statements
259259
)
260260

261261
self._inference_service: Optional[_InferenceService] = None
262-
if self.aad_credentials:
263-
self._inference_service = _InferenceService(self)
264262

265263
self._setup_kwargs: dict[str, Any] = kwargs
266264
self.session: Optional[_session.Session] = None
@@ -295,7 +293,16 @@ def _set_container_properties_cache(self, container_link: str, properties: Optio
295293
self.__container_properties_cache[container_link] = {}
296294

297295
def _get_inference_service(self) -> Optional[_InferenceService]:
    """Return the async inference service, building it on first access when AAD credentials are set.

    Nothing is built while ``self.aad_credentials`` is falsy; the stored value is returned as is.

    :returns: The cached ``_InferenceService`` instance, or ``None`` when none has been created.
    :raises ~azure.cosmos.exceptions.CosmosHttpResponseError: With a BAD_REQUEST status code when
        constructing ``_InferenceService`` raises ``ValueError``.
    """
    # Already built, or no AAD credentials: hand back whatever is stored.
    if self._inference_service is not None or not self.aad_credentials:
        return self._inference_service

    try:
        self._inference_service = _InferenceService(self)
    except ValueError as e:
        raise exceptions.CosmosHttpResponseError(
            message=f"Failed to initialize inference service: {e}",
            response=None,
            status_code=http_constants.StatusCodes.BAD_REQUEST
        ) from e
    return self._inference_service
300307

301308
@property
Lines changed: 135 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,135 @@
1+
# The MIT License (MIT)
2+
# Copyright (c) Microsoft Corporation. All rights reserved.
3+
4+
# cspell:ignore reranker
5+
"""Regression test for lazy initialization of _InferenceService with AAD credentials.
6+
7+
Verifies that constructing a CosmosClient with AAD credentials does not crash
8+
when AZURE_COSMOS_SEMANTIC_RERANKER_INFERENCE_ENDPOINT is not set.
9+
"""
10+
import base64
11+
import json
12+
import os
13+
import time
14+
import unittest
15+
from io import StringIO
16+
17+
import pytest
18+
from azure.core.credentials import AccessToken
19+
20+
import azure.cosmos.cosmos_client as cosmos_client
21+
import test_config
22+
from azure.cosmos import DatabaseProxy, ContainerProxy
23+
24+
25+
def _remove_padding(encoded_string):
26+
while encoded_string.endswith("="):
27+
encoded_string = encoded_string[0:len(encoded_string) - 1]
28+
return encoded_string
29+
30+
31+
def get_test_item(num):
    """Build a test document whose ``id`` is ``LazyInit_`` followed by ``str(num)``."""
    return dict(pk='pk', id='LazyInit_' + str(num), test_object=True)
38+
39+
40+
class CosmosEmulatorCredential:
    """Fake AAD credential for the Cosmos emulator.

    ``get_token`` assembles a dot-separated, three-segment token (header, claims, emulator key).
    Each segment is URL-safe base64 encoded with the trailing ``=`` padding removed.
    """

    # Lifetime in seconds, used for both the ``exp`` claim and the AccessToken expiry.
    _TOKEN_LIFETIME_SECONDS = 7200

    @staticmethod
    def _encode_segment(text):
        """URL-safe base64 encode ``text`` (as UTF-8) and drop the trailing ``=`` padding."""
        return _remove_padding(base64.urlsafe_b64encode(text.encode("utf-8")).decode("utf-8"))

    def get_token(self, *scopes, **kwargs):
        """Return an ``AccessToken`` built from the emulator master key.

        ``scopes`` and ``kwargs`` are accepted but not used.
        """
        # Read the clock once so iat/nbf/exp and the returned expiry cannot straddle a second boundary.
        now = time.time()
        issued_at = int(now)
        expires_at = int(now + self._TOKEN_LIFETIME_SECONDS)

        aad_header_cosmos_emulator = (
            '{"typ":"JWT","alg":"RS256","x5t":"CosmosEmulatorPrimaryMaster",'
            '"kid":"CosmosEmulatorPrimaryMaster"}'
        )
        aad_claim_cosmos_emulator_format = {
            "aud": "https://localhost.localhost",
            "iss": "https://sts.fake-issuer.net/7b1999a1-dfd7-440e-8204-00170979b984",
            "iat": issued_at, "nbf": issued_at,
            "exp": expires_at, "aio": "", "appid": "localhost",
            "appidacr": "1", "idp": "https://localhost:8081/",
            "oid": "96313034-4739-43cb-93cd-74193adbe5b6", "rh": "", "sub": "localhost",
            "tid": "EmulatorFederation", "uti": "", "ver": "1.0",
            "scp": "user_impersonation",
            "groups": ["7ce1d003-4cb3-4879-b7c5-74062a35c66e",
                       "e99ff30c-c229-4c67-ab29-30a6aebc3e58",
                       "5549bb62-c77b-4305-bda9-9ec66b85d9e4",
                       "c44fd685-5c58-452c-aaf7-13ce75184f65",
                       "be895215-eab5-43b7-9536-9ef8fe130330"]}
        # Compact separators yield the space-free JSON directly, without touching spaces inside values.
        claims_json = json.dumps(aad_claim_cosmos_emulator_format, separators=(",", ":"))

        emulator_key = test_config.TestConfig.masterKey
        first = self._encode_segment(aad_header_cosmos_emulator)
        second = self._encode_segment(claims_json)
        third = self._encode_segment(emulator_key)
        return AccessToken(first + "." + second + "." + third, expires_at)
67+
68+
69+
@pytest.mark.cosmosEmulator
@pytest.mark.cosmosLong
class TestAADInferenceServiceLazyInit(unittest.TestCase):
    """Verify AAD client construction succeeds without the semantic reranker env var.

    Before the fix, _InferenceService was eagerly initialized during CosmosClient
    construction whenever AAD credentials were used, causing a ValueError if
    AZURE_COSMOS_SEMANTIC_RERANKER_INFERENCE_ENDPOINT was missing.
    """

    # Environment variable that each test sets or clears.
    _ENDPOINT_ENV_VAR = "AZURE_COSMOS_SEMANTIC_RERANKER_INFERENCE_ENDPOINT"

    client: cosmos_client.CosmosClient = None
    database: DatabaseProxy = None
    container: ContainerProxy = None
    configs = test_config.TestConfig
    host = configs.host
    masterKey = configs.masterKey
    credential = CosmosEmulatorCredential() if configs.is_emulator else configs.credential

    def setUp(self):
        """Remember the env var value so tearDown can put it back."""
        self._saved_endpoint = os.environ.get(self._ENDPOINT_ENV_VAR)

    def tearDown(self):
        """Restore the env var to its pre-test state (unset it if it was unset before)."""
        if self._saved_endpoint is None:
            os.environ.pop(self._ENDPOINT_ENV_VAR, None)
        else:
            os.environ[self._ENDPOINT_ENV_VAR] = self._saved_endpoint

    def _get_test_container(self, client):
        """Return the single-partition test container proxy for ``client``."""
        db = client.get_database_client(self.configs.TEST_DATABASE_ID)
        return db.get_container_client(self.configs.TEST_SINGLE_PARTITION_CONTAINER_ID)

    def test_aad_client_construction_without_inference_endpoint(self):
        """Constructing a CosmosClient with AAD creds must not raise when
        AZURE_COSMOS_SEMANTIC_RERANKER_INFERENCE_ENDPOINT is unset."""
        os.environ.pop(self._ENDPOINT_ENV_VAR, None)

        # The context manager closes the client even when an assertion fails.
        with cosmos_client.CosmosClient(self.host, self.credential) as client:
            container = self._get_test_container(client)

            item = get_test_item(0)
            container.create_item(item)
            try:
                read_result = container.read_item(item=item['id'], partition_key='pk')
                self.assertEqual(read_result['id'], item['id'])

                query_results = list(container.query_items(
                    query='SELECT * FROM c WHERE c.id=@id',
                    parameters=[{"name": "@id", "value": item['id']}],
                    partition_key='pk'
                ))
                self.assertEqual(len(query_results), 1)
            finally:
                # Always delete the item so a failed run does not leave it behind for the next one.
                container.delete_item(item=item['id'], partition_key='pk')

    def test_aad_client_construction_with_inference_endpoint(self):
        """Constructing a CosmosClient with AAD creds must also work when
        AZURE_COSMOS_SEMANTIC_RERANKER_INFERENCE_ENDPOINT IS set."""
        os.environ[self._ENDPOINT_ENV_VAR] = "https://example.com"

        with cosmos_client.CosmosClient(self.host, self.credential) as client:
            container = self._get_test_container(client)

            item = get_test_item(1)
            container.create_item(item)
            try:
                read_result = container.read_item(item=item['id'], partition_key='pk')
                self.assertEqual(read_result['id'], item['id'])
            finally:
                container.delete_item(item=item['id'], partition_key='pk')
132+
133+
134+
if __name__ == "__main__":
135+
unittest.main()
Lines changed: 138 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,138 @@
1+
# The MIT License (MIT)
2+
# Copyright (c) Microsoft Corporation. All rights reserved.
3+
4+
# cspell:ignore reranker
5+
"""Async regression test for lazy initialization of _InferenceService with AAD credentials.
6+
7+
Verifies that constructing an async CosmosClient with AAD credentials does not crash
8+
when AZURE_COSMOS_SEMANTIC_RERANKER_INFERENCE_ENDPOINT is not set.
9+
"""
10+
import base64
11+
import json
12+
import os
13+
import time
14+
import unittest
15+
from io import StringIO
16+
17+
import pytest
18+
from azure.core.credentials import AccessToken
19+
20+
import test_config
21+
from azure.cosmos.aio import CosmosClient, DatabaseProxy, ContainerProxy
22+
23+
24+
def _remove_padding(encoded_string):
25+
while encoded_string.endswith("="):
26+
encoded_string = encoded_string[0:len(encoded_string) - 1]
27+
return encoded_string
28+
29+
30+
def get_test_item(num):
    """Build a test document whose ``id`` is ``LazyInitAsync_`` followed by ``str(num)``."""
    return dict(pk='pk', id='LazyInitAsync_' + str(num), test_object=True)
37+
38+
39+
class CosmosEmulatorCredential:
    """Fake async AAD credential for the Cosmos emulator.

    ``get_token`` assembles a dot-separated, three-segment token (header, claims, emulator key).
    Each segment is URL-safe base64 encoded with the trailing ``=`` padding removed.
    """

    # Lifetime in seconds, used for both the ``exp`` claim and the AccessToken expiry.
    _TOKEN_LIFETIME_SECONDS = 7200

    @staticmethod
    def _encode_segment(text):
        """URL-safe base64 encode ``text`` (as UTF-8) and drop the trailing ``=`` padding."""
        return _remove_padding(base64.urlsafe_b64encode(text.encode("utf-8")).decode("utf-8"))

    async def get_token(self, *scopes, **kwargs):
        """Return an ``AccessToken`` built from the emulator master key.

        ``scopes`` and ``kwargs`` are accepted but not used.
        """
        # Read the clock once so iat/nbf/exp and the returned expiry cannot straddle a second boundary.
        now = time.time()
        issued_at = int(now)
        expires_at = int(now + self._TOKEN_LIFETIME_SECONDS)

        aad_header_cosmos_emulator = (
            '{"typ":"JWT","alg":"RS256","x5t":"CosmosEmulatorPrimaryMaster",'
            '"kid":"CosmosEmulatorPrimaryMaster"}'
        )
        aad_claim_cosmos_emulator_format = {
            "aud": "https://localhost.localhost",
            "iss": "https://sts.fake-issuer.net/7b1999a1-dfd7-440e-8204-00170979b984",
            "iat": issued_at, "nbf": issued_at,
            "exp": expires_at, "aio": "", "appid": "localhost",
            "appidacr": "1", "idp": "https://localhost:8081/",
            "oid": "96313034-4739-43cb-93cd-74193adbe5b6", "rh": "", "sub": "localhost",
            "tid": "EmulatorFederation", "uti": "", "ver": "1.0",
            "scp": "user_impersonation",
            "groups": ["7ce1d003-4cb3-4879-b7c5-74062a35c66e",
                       "e99ff30c-c229-4c67-ab29-30a6aebc3e58",
                       "5549bb62-c77b-4305-bda9-9ec66b85d9e4",
                       "c44fd685-5c58-452c-aaf7-13ce75184f65",
                       "be895215-eab5-43b7-9536-9ef8fe130330"]}
        # Compact separators yield the space-free JSON directly, without touching spaces inside values.
        claims_json = json.dumps(aad_claim_cosmos_emulator_format, separators=(",", ":"))

        emulator_key = test_config.TestConfig.masterKey
        first = self._encode_segment(aad_header_cosmos_emulator)
        second = self._encode_segment(claims_json)
        third = self._encode_segment(emulator_key)
        return AccessToken(first + "." + second + "." + third, expires_at)
66+
67+
68+
@pytest.mark.cosmosEmulator
@pytest.mark.cosmosLong
class TestAADInferenceServiceLazyInitAsync(unittest.IsolatedAsyncioTestCase):
    """Verify async AAD client construction succeeds without the semantic reranker env var.

    Before the fix, _InferenceService was eagerly initialized during CosmosClient
    construction whenever AAD credentials were used, causing a ValueError if
    AZURE_COSMOS_SEMANTIC_RERANKER_INFERENCE_ENDPOINT was missing.
    """

    # Environment variable that each test sets or clears.
    _ENDPOINT_ENV_VAR = "AZURE_COSMOS_SEMANTIC_RERANKER_INFERENCE_ENDPOINT"

    client: CosmosClient = None
    database: DatabaseProxy = None
    container: ContainerProxy = None
    configs = test_config.TestConfig
    host = configs.host
    masterKey = configs.masterKey
    credential = CosmosEmulatorCredential() if configs.is_emulator else configs.credential_async

    def setUp(self):
        """Remember the env var value so tearDown can put it back."""
        self._saved_endpoint = os.environ.get(self._ENDPOINT_ENV_VAR)

    def tearDown(self):
        """Restore the env var to its pre-test state (unset it if it was unset before)."""
        if self._saved_endpoint is None:
            os.environ.pop(self._ENDPOINT_ENV_VAR, None)
        else:
            os.environ[self._ENDPOINT_ENV_VAR] = self._saved_endpoint

    def _get_test_container(self, client):
        """Return the single-partition test container proxy for ``client``."""
        db = client.get_database_client(self.configs.TEST_DATABASE_ID)
        return db.get_container_client(self.configs.TEST_SINGLE_PARTITION_CONTAINER_ID)

    async def test_aad_client_construction_without_inference_endpoint(self):
        """Constructing an async CosmosClient with AAD creds must not raise when
        AZURE_COSMOS_SEMANTIC_RERANKER_INFERENCE_ENDPOINT is unset."""
        os.environ.pop(self._ENDPOINT_ENV_VAR, None)

        # The async context manager closes the client even when an assertion fails.
        async with CosmosClient(self.host, self.credential) as client:
            container = self._get_test_container(client)

            item = get_test_item(0)
            await container.create_item(item)
            try:
                read_result = await container.read_item(item=item['id'], partition_key='pk')
                self.assertEqual(read_result['id'], item['id'])

                query_results = [
                    doc async for doc in container.query_items(
                        query='SELECT * FROM c WHERE c.id=@id',
                        parameters=[{"name": "@id", "value": item['id']}],
                        partition_key='pk'
                    )
                ]
                self.assertEqual(len(query_results), 1)
            finally:
                # Always delete the item so a failed run does not leave it behind for the next one.
                await container.delete_item(item=item['id'], partition_key='pk')

    async def test_aad_client_construction_with_inference_endpoint(self):
        """Constructing an async CosmosClient with AAD creds must also work when
        AZURE_COSMOS_SEMANTIC_RERANKER_INFERENCE_ENDPOINT IS set."""
        os.environ[self._ENDPOINT_ENV_VAR] = "https://placeholder.example.com"

        async with CosmosClient(self.host, self.credential) as client:
            container = self._get_test_container(client)

            item = get_test_item(1)
            await container.create_item(item)
            try:
                read_result = await container.read_item(item=item['id'], partition_key='pk')
                self.assertEqual(read_result['id'], item['id'])
            finally:
                await container.delete_item(item=item['id'], partition_key='pk')
135+
136+
137+
if __name__ == "__main__":
138+
unittest.main()

0 commit comments

Comments
 (0)