Skip to content

Commit f40d778

Browse files
committed
chore: Use context.Context from InvocationContext
1 parent aa82547 commit f40d778

File tree

21 files changed

+202
-135
lines changed

21 files changed

+202
-135
lines changed

internal/api/api.go

Lines changed: 57 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package api
22

33
import (
4+
"context"
45
"encoding/json"
56
"fmt"
67
"io"
@@ -14,15 +15,15 @@ import (
1415
//go:generate go tool github.com/golang/mock/mockgen -source=api.go -destination ../mocks/api.go -package mocks -self_package github.com/snyk/go-application-framework/pkg/api/
1516

1617
type ApiClient interface {
17-
GetDefaultOrgId() (orgID string, err error)
18-
GetOrgIdFromSlug(slugName string) (string, error)
19-
GetSlugFromOrgId(orgID string) (string, error)
20-
GetOrganizations(limit int) (*contract.OrganizationsResponse, error)
18+
GetDefaultOrgId(ctx context.Context) (orgID string, err error)
19+
GetOrgIdFromSlug(ctx context.Context, slugName string) (string, error)
20+
GetSlugFromOrgId(ctx context.Context, orgID string) (string, error)
21+
GetOrganizations(ctx context.Context, limit int) (*contract.OrganizationsResponse, error)
2122
Init(url string, client *http.Client)
22-
GetFeatureFlag(flagname string, origId string) (bool, error)
23-
GetUserMe() (string, error)
24-
GetSelf() (contract.SelfResponse, error)
25-
GetOrgSettings(orgId string) (*contract.OrgSettingsResponse, error)
23+
GetFeatureFlag(ctx context.Context, flagname string, origId string) (bool, error)
24+
GetUserMe(ctx context.Context) (string, error)
25+
GetSelf(ctx context.Context) (contract.SelfResponse, error)
26+
GetOrgSettings(ctx context.Context, orgId string) (*contract.OrgSettingsResponse, error)
2627
}
2728

2829
var _ ApiClient = (*snykApiClient)(nil)
@@ -35,17 +36,18 @@ type snykApiClient struct {
3536
// GetSlugFromOrgId retrieves the organization slug associated with a given Snyk organization ID.
3637
//
3738
// Parameters:
39+
// - ctx: Context for cancellation and timeout control.
3840
// - orgID (string): The UUID of the organization.
3941
//
4042
// Returns:
4143
// - The organization slug as a string.
4244
// - An error object (if the organization is not found, or if API request or response
4345
// parsing errors occur).
44-
func (a *snykApiClient) GetSlugFromOrgId(orgID string) (string, error) {
46+
func (a *snykApiClient) GetSlugFromOrgId(ctx context.Context, orgID string) (string, error) {
4547
endpoint := "/rest/orgs/" + orgID
4648
version := "2024-03-12"
4749

48-
body, err := clientGet(a, endpoint, &version)
50+
body, err := clientGet(ctx, a, endpoint, &version)
4951
if err != nil {
5052
return "", err
5153
}
@@ -62,17 +64,18 @@ func (a *snykApiClient) GetSlugFromOrgId(orgID string) (string, error) {
6264
// GetOrgIdFromSlug retrieves the organization ID associated with a given Snyk organization slug.
6365
//
6466
// Parameters:
67+
// - ctx: Context for cancellation and timeout control.
6568
// - slugName (string): The unique slug identifier of the organization.
6669
//
6770
// Returns:
6871
// - The organization ID as a string.
6972
// - An error object (if the organization is not found, or if API request or response
7073
// parsing errors occur).
71-
func (a *snykApiClient) GetOrgIdFromSlug(slugName string) (string, error) {
74+
func (a *snykApiClient) GetOrgIdFromSlug(ctx context.Context, slugName string) (string, error) {
7275
endpoint := "/rest/orgs"
7376
version := "2024-03-12"
7477

75-
body, err := clientGet(a, endpoint, &version, "slug", slugName)
78+
body, err := clientGet(ctx, a, endpoint, &version, "slug", slugName)
7679
if err != nil {
7780
return "", err
7881
}
@@ -95,16 +98,17 @@ func (a *snykApiClient) GetOrgIdFromSlug(slugName string) (string, error) {
9598
// GetOrganizations retrieves organizations accessible to the authenticated user.
9699
//
97100
// Parameters:
101+
// - ctx: Context for cancellation and timeout control.
98102
// - limit: Maximum number of organizations to return
99103
//
100104
// Returns:
101105
// - A pointer to OrganizationsResponse containing organizations.
102106
// - An error object (if an error occurred during the API request or response parsing).
103-
func (a *snykApiClient) GetOrganizations(limit int) (*contract.OrganizationsResponse, error) {
107+
func (a *snykApiClient) GetOrganizations(ctx context.Context, limit int) (*contract.OrganizationsResponse, error) {
104108
endpoint := "/rest/orgs"
105109
version := "2024-10-15"
106110

107-
body, err := clientGet(a, endpoint, &version, "limit", fmt.Sprintf("%d", limit))
111+
body, err := clientGet(ctx, a, endpoint, &version, "limit", fmt.Sprintf("%d", limit))
108112
if err != nil {
109113
return nil, err
110114
}
@@ -120,11 +124,14 @@ func (a *snykApiClient) GetOrganizations(limit int) (*contract.OrganizationsResp
120124

121125
// GetDefaultOrgId retrieves the default organization ID associated with the authenticated user.
122126
//
127+
// Parameters:
128+
// - ctx: Context for cancellation and timeout control.
129+
//
123130
// Returns:
124131
// - The user's default organization ID as a string.
125132
// - An error object (if an error occurred while fetching user data).
126-
func (a *snykApiClient) GetDefaultOrgId() (string, error) {
127-
selfData, err := a.GetSelf()
133+
func (a *snykApiClient) GetDefaultOrgId(ctx context.Context) (string, error) {
134+
selfData, err := a.GetSelf(ctx)
128135
if err != nil {
129136
return "", fmt.Errorf("unable to retrieve org ID: %w", err)
130137
}
@@ -134,11 +141,14 @@ func (a *snykApiClient) GetDefaultOrgId() (string, error) {
134141

135142
// GetUserMe retrieves the username for the authenticated user from the Snyk API.
136143
//
144+
// Parameters:
145+
// - ctx: Context for cancellation and timeout control.
146+
//
137147
// Returns:
138148
// - The authenticated user's username as a string.
139149
// - An error object (if an error occurred while fetching user data or extracting the username).
140-
func (a *snykApiClient) GetUserMe() (string, error) {
141-
selfData, err := a.GetSelf()
150+
func (a *snykApiClient) GetUserMe(ctx context.Context) (string, error) {
151+
selfData, err := a.GetSelf(ctx)
142152
if err != nil {
143153
return "", fmt.Errorf("error while fetching self data: %w", err) // Prioritize error
144154
}
@@ -161,14 +171,15 @@ func (a *snykApiClient) GetUserMe() (string, error) {
161171
// GetFeatureFlag determines the state of a feature flag for the specified organization.
162172
//
163173
// Parameters:
174+
// - ctx: Context for cancellation and timeout control.
164175
// - flagname (string): The name of the feature flag to check.
165176
// - orgId (string): The ID of the organization associated with the feature flag.
166177
//
167178
// Returns:
168179
// - A boolean indicating if the feature flag is enabled (true) or disabled (false).
169180
// - An error object (if an error occurred during the API request, response parsing,
170181
// or if the organization ID is invalid).
171-
func (a *snykApiClient) GetFeatureFlag(flagname string, orgId string) (bool, error) {
182+
func (a *snykApiClient) GetFeatureFlag(ctx context.Context, flagname string, orgId string) (bool, error) {
172183
const defaultResult = false
173184

174185
u := a.url + "/v1/cli-config/feature-flags/" + flagname + "?org=" + orgId
@@ -177,7 +188,12 @@ func (a *snykApiClient) GetFeatureFlag(flagname string, orgId string) (bool, err
177188
return defaultResult, fmt.Errorf("failed to lookup feature flag with orgiId not set")
178189
}
179190

180-
res, err := a.client.Get(u)
191+
req, err := http.NewRequestWithContext(ctx, http.MethodGet, u, nil)
192+
if err != nil {
193+
return defaultResult, fmt.Errorf("unable to create request: %w", err)
194+
}
195+
196+
res, err := a.client.Do(req)
181197
if err != nil {
182198
return defaultResult, fmt.Errorf("unable to retrieve feature flag: %w", err)
183199
}
@@ -204,14 +220,17 @@ func (a *snykApiClient) GetFeatureFlag(flagname string, orgId string) (bool, err
204220

205221
// GetSelf retrieves the authenticated user's information from the Snyk API.
206222
//
223+
// Parameters:
224+
// - ctx: Context for cancellation and timeout control.
225+
//
207226
// Returns:
208227
// - A `contract.SelfResponse` struct containing the user's data.
209228
// - An error object (if an error occurred during the API request or response parsing).
210-
func (a *snykApiClient) GetSelf() (contract.SelfResponse, error) {
229+
func (a *snykApiClient) GetSelf(ctx context.Context) (contract.SelfResponse, error) {
211230
endpoint := "/rest/self"
212231
var selfData contract.SelfResponse
213232

214-
body, err := clientGet(a, endpoint, nil)
233+
body, err := clientGet(ctx, a, endpoint, nil)
215234
if err != nil {
216235
return selfData, err
217236
}
@@ -223,10 +242,15 @@ func (a *snykApiClient) GetSelf() (contract.SelfResponse, error) {
223242
return selfData, nil
224243
}
225244

226-
func (a *snykApiClient) GetOrgSettings(orgId string) (*contract.OrgSettingsResponse, error) {
245+
func (a *snykApiClient) GetOrgSettings(ctx context.Context, orgId string) (*contract.OrgSettingsResponse, error) {
227246
endpoint := fmt.Sprintf("%s/v1/org/%s/settings", a.url, url.QueryEscape(orgId))
228247

229-
res, err := a.client.Get(endpoint)
248+
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
249+
if err != nil {
250+
return nil, fmt.Errorf("unable to create request: %w", err)
251+
}
252+
253+
res, err := a.client.Do(req)
230254
if err != nil {
231255
return nil, fmt.Errorf("unable to retrieve org settings: %w", err)
232256
}
@@ -250,6 +274,7 @@ func (a *snykApiClient) GetOrgSettings(orgId string) (*contract.OrgSettingsRespo
250274
// API versioning, and basic error checking.
251275
//
252276
// Parameters:
277+
// - ctx: Context for cancellation and timeout control.
253278
// - a (snykApiClient): A reference to the Snyk API client object.
254279
// - endpoint (string): The endpoint path to be appended to the API base URL.
255280
// - version (*string): An optional pointer to a string specifying the desired API version.
@@ -264,8 +289,8 @@ func (a *snykApiClient) GetOrgSettings(orgId string) (*contract.OrgSettingsRespo
264289
//
265290
// Example:
266291
// apiVersion := "2022-01-12"
267-
// response, err := clientGet(myApiClient, "/organizations", &apiVersion, "limit", "50")
268-
func clientGet(a *snykApiClient, endpoint string, version *string, queryParams ...string) ([]byte, error) {
292+
// response, err := clientGet(ctx, myApiClient, "/organizations", &apiVersion, "limit", "50")
293+
func clientGet(ctx context.Context, a *snykApiClient, endpoint string, version *string, queryParams ...string) ([]byte, error) {
269294
var apiVersion string = constants.SNYK_DEFAULT_API_VERSION
270295
if version != nil && *version != "" {
271296
apiVersion = *version
@@ -277,7 +302,12 @@ func clientGet(a *snykApiClient, endpoint string, version *string, queryParams .
277302
return nil, err
278303
}
279304

280-
res, err := a.client.Get(url.String())
305+
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url.String(), nil)
306+
if err != nil {
307+
return nil, err
308+
}
309+
310+
res, err := a.client.Do(req)
281311
if err != nil {
282312
return nil, err
283313
}

internal/api/api_test.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ func Test_GetDefaultOrgId_ReturnsCorrectOrgId(t *testing.T) {
2525
client := api.NewApi(server.URL, http.DefaultClient)
2626

2727
// Act
28-
orgId, err := client.GetDefaultOrgId()
28+
orgId, err := client.GetDefaultOrgId(t.Context())
2929
if err != nil {
3030
t.Error(err)
3131
}
@@ -46,7 +46,7 @@ func Test_GetSlugFromOrgId_ReturnsCorrectSlug(t *testing.T) {
4646
client := api.NewApi(server.URL, http.DefaultClient)
4747

4848
// Act
49-
actualSlug, err := client.GetSlugFromOrgId(orgID)
49+
actualSlug, err := client.GetSlugFromOrgId(t.Context(), orgID)
5050
if err != nil {
5151
t.Error(err)
5252
}
@@ -66,7 +66,7 @@ func Test_GetOrganizations_ReturnsOrganizations(t *testing.T) {
6666
client := api.NewApi(server.URL, http.DefaultClient)
6767

6868
// Act
69-
response, err := client.GetOrganizations(limit)
69+
response, err := client.GetOrganizations(t.Context(), limit)
7070
if err != nil {
7171
t.Error(err)
7272
}
@@ -93,7 +93,7 @@ func Test_GetOrgIdFromSlug_ReturnsCorrectOrgId(t *testing.T) {
9393
apiClient := api.NewApi(server.URL, http.DefaultClient)
9494

9595
// Act
96-
orgId, err := apiClient.GetOrgIdFromSlug(slugName)
96+
orgId, err := apiClient.GetOrgIdFromSlug(t.Context(), slugName)
9797
if err != nil {
9898
t.Error(err)
9999
}
@@ -115,11 +115,11 @@ func Test_GetFeatureFlag_false(t *testing.T) {
115115
server := setupSingleReponseServer(t, "/v1/cli-config/feature-flags/"+featureFlagName+"?org="+org, featureFlagResponse)
116116
client := api.NewApi(server.URL, http.DefaultClient)
117117

118-
actual, err := client.GetFeatureFlag(featureFlagName, org)
118+
actual, err := client.GetFeatureFlag(t.Context(), featureFlagName, org)
119119
assert.NoError(t, err)
120120
assert.False(t, actual)
121121

122-
actual, err = client.GetFeatureFlag("unknownFF", org)
122+
actual, err = client.GetFeatureFlag(t.Context(), "unknownFF", org)
123123
assert.Error(t, err)
124124
assert.False(t, actual)
125125
}
@@ -137,7 +137,7 @@ func Test_GetFeatureFlag_true(t *testing.T) {
137137
server := setupSingleReponseServer(t, "/v1/cli-config/feature-flags/"+featureFlagName+"?org="+org, featureFlagResponse)
138138
client := api.NewApi(server.URL, http.DefaultClient)
139139

140-
actual, err := client.GetFeatureFlag(featureFlagName, org)
140+
actual, err := client.GetFeatureFlag(t.Context(), featureFlagName, org)
141141
assert.NoError(t, err)
142142
assert.True(t, actual)
143143
}

0 commit comments

Comments
 (0)