Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 57 additions & 27 deletions internal/api/api.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package api

import (
"context"
"encoding/json"
"fmt"
"io"
Expand All @@ -14,15 +15,15 @@ import (
//go:generate go tool github.com/golang/mock/mockgen -source=api.go -destination ../mocks/api.go -package mocks -self_package github.com/snyk/go-application-framework/pkg/api/

type ApiClient interface {
GetDefaultOrgId() (orgID string, err error)
GetOrgIdFromSlug(slugName string) (string, error)
GetSlugFromOrgId(orgID string) (string, error)
GetOrganizations(limit int) (*contract.OrganizationsResponse, error)
GetDefaultOrgId(ctx context.Context) (orgID string, err error)
GetOrgIdFromSlug(ctx context.Context, slugName string) (string, error)
GetSlugFromOrgId(ctx context.Context, orgID string) (string, error)
GetOrganizations(ctx context.Context, limit int) (*contract.OrganizationsResponse, error)
Init(url string, client *http.Client)
GetFeatureFlag(flagname string, origId string) (bool, error)
GetUserMe() (string, error)
GetSelf() (contract.SelfResponse, error)
GetOrgSettings(orgId string) (*contract.OrgSettingsResponse, error)
GetFeatureFlag(ctx context.Context, flagname string, origId string) (bool, error)
GetUserMe(ctx context.Context) (string, error)
GetSelf(ctx context.Context) (contract.SelfResponse, error)
GetOrgSettings(ctx context.Context, orgId string) (*contract.OrgSettingsResponse, error)
}

var _ ApiClient = (*snykApiClient)(nil)
Expand All @@ -35,17 +36,18 @@ type snykApiClient struct {
// GetSlugFromOrgId retrieves the organization slug associated with a given Snyk organization ID.
//
// Parameters:
// - ctx: Context for cancellation and timeout control.
// - orgID (string): The UUID of the organization.
//
// Returns:
// - The organization slug as a string.
// - An error object (if the organization is not found, or if API request or response
// parsing errors occur).
func (a *snykApiClient) GetSlugFromOrgId(orgID string) (string, error) {
func (a *snykApiClient) GetSlugFromOrgId(ctx context.Context, orgID string) (string, error) {
endpoint := "/rest/orgs/" + orgID
version := "2024-03-12"

body, err := clientGet(a, endpoint, &version)
body, err := clientGet(ctx, a, endpoint, &version)
if err != nil {
return "", err
}
Expand All @@ -62,17 +64,18 @@ func (a *snykApiClient) GetSlugFromOrgId(orgID string) (string, error) {
// GetOrgIdFromSlug retrieves the organization ID associated with a given Snyk organization slug.
//
// Parameters:
// - ctx: Context for cancellation and timeout control.
// - slugName (string): The unique slug identifier of the organization.
//
// Returns:
// - The organization ID as a string.
// - An error object (if the organization is not found, or if API request or response
// parsing errors occur).
func (a *snykApiClient) GetOrgIdFromSlug(slugName string) (string, error) {
func (a *snykApiClient) GetOrgIdFromSlug(ctx context.Context, slugName string) (string, error) {
endpoint := "/rest/orgs"
version := "2024-03-12"

body, err := clientGet(a, endpoint, &version, "slug", slugName)
body, err := clientGet(ctx, a, endpoint, &version, "slug", slugName)
if err != nil {
return "", err
}
Expand All @@ -95,16 +98,17 @@ func (a *snykApiClient) GetOrgIdFromSlug(slugName string) (string, error) {
// GetOrganizations retrieves organizations accessible to the authenticated user.
//
// Parameters:
// - ctx: Context for cancellation and timeout control.
// - limit: Maximum number of organizations to return
//
// Returns:
// - A pointer to OrganizationsResponse containing organizations.
// - An error object (if an error occurred during the API request or response parsing).
func (a *snykApiClient) GetOrganizations(limit int) (*contract.OrganizationsResponse, error) {
func (a *snykApiClient) GetOrganizations(ctx context.Context, limit int) (*contract.OrganizationsResponse, error) {
endpoint := "/rest/orgs"
version := "2024-10-15"

body, err := clientGet(a, endpoint, &version, "limit", fmt.Sprintf("%d", limit))
body, err := clientGet(ctx, a, endpoint, &version, "limit", fmt.Sprintf("%d", limit))
if err != nil {
return nil, err
}
Expand All @@ -120,11 +124,14 @@ func (a *snykApiClient) GetOrganizations(limit int) (*contract.OrganizationsResp

// GetDefaultOrgId retrieves the default organization ID associated with the authenticated user.
//
// Parameters:
// - ctx: Context for cancellation and timeout control.
//
// Returns:
// - The user's default organization ID as a string.
// - An error object (if an error occurred while fetching user data).
func (a *snykApiClient) GetDefaultOrgId() (string, error) {
selfData, err := a.GetSelf()
func (a *snykApiClient) GetDefaultOrgId(ctx context.Context) (string, error) {
selfData, err := a.GetSelf(ctx)
if err != nil {
return "", fmt.Errorf("unable to retrieve org ID: %w", err)
}
Expand All @@ -134,11 +141,14 @@ func (a *snykApiClient) GetDefaultOrgId() (string, error) {

// GetUserMe retrieves the username for the authenticated user from the Snyk API.
//
// Parameters:
// - ctx: Context for cancellation and timeout control.
//
// Returns:
// - The authenticated user's username as a string.
// - An error object (if an error occurred while fetching user data or extracting the username).
func (a *snykApiClient) GetUserMe() (string, error) {
selfData, err := a.GetSelf()
func (a *snykApiClient) GetUserMe(ctx context.Context) (string, error) {
selfData, err := a.GetSelf(ctx)
if err != nil {
return "", fmt.Errorf("error while fetching self data: %w", err) // Prioritize error
}
Expand All @@ -161,14 +171,15 @@ func (a *snykApiClient) GetUserMe() (string, error) {
// GetFeatureFlag determines the state of a feature flag for the specified organization.
//
// Parameters:
// - ctx: Context for cancellation and timeout control.
// - flagname (string): The name of the feature flag to check.
// - orgId (string): The ID of the organization associated with the feature flag.
//
// Returns:
// - A boolean indicating if the feature flag is enabled (true) or disabled (false).
// - An error object (if an error occurred during the API request, response parsing,
// or if the organization ID is invalid).
func (a *snykApiClient) GetFeatureFlag(flagname string, orgId string) (bool, error) {
func (a *snykApiClient) GetFeatureFlag(ctx context.Context, flagname string, orgId string) (bool, error) {
const defaultResult = false

u := a.url + "/v1/cli-config/feature-flags/" + flagname + "?org=" + orgId
Expand All @@ -177,7 +188,12 @@ func (a *snykApiClient) GetFeatureFlag(flagname string, orgId string) (bool, err
return defaultResult, fmt.Errorf("failed to lookup feature flag with orgiId not set")
}

res, err := a.client.Get(u)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, u, nil)
if err != nil {
return defaultResult, fmt.Errorf("unable to create request: %w", err)
}

res, err := a.client.Do(req)
if err != nil {
return defaultResult, fmt.Errorf("unable to retrieve feature flag: %w", err)
}
Expand All @@ -204,14 +220,17 @@ func (a *snykApiClient) GetFeatureFlag(flagname string, orgId string) (bool, err

// GetSelf retrieves the authenticated user's information from the Snyk API.
//
// Parameters:
// - ctx: Context for cancellation and timeout control.
//
// Returns:
// - A `contract.SelfResponse` struct containing the user's data.
// - An error object (if an error occurred during the API request or response parsing).
func (a *snykApiClient) GetSelf() (contract.SelfResponse, error) {
func (a *snykApiClient) GetSelf(ctx context.Context) (contract.SelfResponse, error) {
endpoint := "/rest/self"
var selfData contract.SelfResponse

body, err := clientGet(a, endpoint, nil)
body, err := clientGet(ctx, a, endpoint, nil)
if err != nil {
return selfData, err
}
Expand All @@ -223,10 +242,15 @@ func (a *snykApiClient) GetSelf() (contract.SelfResponse, error) {
return selfData, nil
}

func (a *snykApiClient) GetOrgSettings(orgId string) (*contract.OrgSettingsResponse, error) {
func (a *snykApiClient) GetOrgSettings(ctx context.Context, orgId string) (*contract.OrgSettingsResponse, error) {
endpoint := fmt.Sprintf("%s/v1/org/%s/settings", a.url, url.QueryEscape(orgId))

res, err := a.client.Get(endpoint)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
if err != nil {
return nil, fmt.Errorf("unable to create request: %w", err)
}

res, err := a.client.Do(req)
if err != nil {
return nil, fmt.Errorf("unable to retrieve org settings: %w", err)
}
Expand All @@ -250,6 +274,7 @@ func (a *snykApiClient) GetOrgSettings(orgId string) (*contract.OrgSettingsRespo
// API versioning, and basic error checking.
//
// Parameters:
// - ctx: Context for cancellation and timeout control.
// - a (snykApiClient): A reference to the Snyk API client object.
// - endpoint (string): The endpoint path to be appended to the API base URL.
// - version (*string): An optional pointer to a string specifying the desired API version.
Expand All @@ -264,8 +289,8 @@ func (a *snykApiClient) GetOrgSettings(orgId string) (*contract.OrgSettingsRespo
//
// Example:
// apiVersion := "2022-01-12"
// response, err := clientGet(myApiClient, "/organizations", &apiVersion, "limit", "50")
func clientGet(a *snykApiClient, endpoint string, version *string, queryParams ...string) ([]byte, error) {
// response, err := clientGet(ctx, myApiClient, "/organizations", &apiVersion, "limit", "50")
func clientGet(ctx context.Context, a *snykApiClient, endpoint string, version *string, queryParams ...string) ([]byte, error) {
var apiVersion string = constants.SNYK_DEFAULT_API_VERSION
if version != nil && *version != "" {
apiVersion = *version
Expand All @@ -277,7 +302,12 @@ func clientGet(a *snykApiClient, endpoint string, version *string, queryParams .
return nil, err
}

res, err := a.client.Get(url.String())
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url.String(), nil)
if err != nil {
return nil, err
}

res, err := a.client.Do(req)
if err != nil {
return nil, err
}
Expand Down
14 changes: 7 additions & 7 deletions internal/api/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ func Test_GetDefaultOrgId_ReturnsCorrectOrgId(t *testing.T) {
client := api.NewApi(server.URL, http.DefaultClient)

// Act
orgId, err := client.GetDefaultOrgId()
orgId, err := client.GetDefaultOrgId(t.Context())
if err != nil {
t.Error(err)
}
Expand All @@ -46,7 +46,7 @@ func Test_GetSlugFromOrgId_ReturnsCorrectSlug(t *testing.T) {
client := api.NewApi(server.URL, http.DefaultClient)

// Act
actualSlug, err := client.GetSlugFromOrgId(orgID)
actualSlug, err := client.GetSlugFromOrgId(t.Context(), orgID)
if err != nil {
t.Error(err)
}
Expand All @@ -66,7 +66,7 @@ func Test_GetOrganizations_ReturnsOrganizations(t *testing.T) {
client := api.NewApi(server.URL, http.DefaultClient)

// Act
response, err := client.GetOrganizations(limit)
response, err := client.GetOrganizations(t.Context(), limit)
if err != nil {
t.Error(err)
}
Expand All @@ -93,7 +93,7 @@ func Test_GetOrgIdFromSlug_ReturnsCorrectOrgId(t *testing.T) {
apiClient := api.NewApi(server.URL, http.DefaultClient)

// Act
orgId, err := apiClient.GetOrgIdFromSlug(slugName)
orgId, err := apiClient.GetOrgIdFromSlug(t.Context(), slugName)
if err != nil {
t.Error(err)
}
Expand All @@ -115,11 +115,11 @@ func Test_GetFeatureFlag_false(t *testing.T) {
server := setupSingleReponseServer(t, "/v1/cli-config/feature-flags/"+featureFlagName+"?org="+org, featureFlagResponse)
client := api.NewApi(server.URL, http.DefaultClient)

actual, err := client.GetFeatureFlag(featureFlagName, org)
actual, err := client.GetFeatureFlag(t.Context(), featureFlagName, org)
assert.NoError(t, err)
assert.False(t, actual)

actual, err = client.GetFeatureFlag("unknownFF", org)
actual, err = client.GetFeatureFlag(t.Context(), "unknownFF", org)
assert.Error(t, err)
assert.False(t, actual)
}
Expand All @@ -137,7 +137,7 @@ func Test_GetFeatureFlag_true(t *testing.T) {
server := setupSingleReponseServer(t, "/v1/cli-config/feature-flags/"+featureFlagName+"?org="+org, featureFlagResponse)
client := api.NewApi(server.URL, http.DefaultClient)

actual, err := client.GetFeatureFlag(featureFlagName, org)
actual, err := client.GetFeatureFlag(t.Context(), featureFlagName, org)
assert.NoError(t, err)
assert.True(t, actual)
}
Expand Down
Loading