Skip to content

Commit 5416688

Browse files
wynbennettci.datadog-api-spec
andauthored
Add ability for client to use delegated token authentication (DataDog#3133)
* Add ability for client to use delegated token authentication in place of api and app key authentication * Cleaned up code quality violation * Trim the api prefix if it is present in the name serverVars * Added tests and general code cleanup * Updated how we pass the token to the subsequent requests * pre-commit fixes * Update code generation for client * pre-commit fixes * Fixed non-standard indentation * Refactored to structure to put delegated auth into the datadog package. Added templates for aws.go and delegated_auth.go * Fixed test * Refactored how we get the server URL for delegated auth to use the standard ServerConfigurations * pre-commit fixes * Cleaned up generator and test * AUTHN-4821 - Refactored delegated auth to use not use the context as much and not to hijack the SetAuthKeys * AUTHN-4821 - Cleaned up template code * pre-commit fixes * AUTHN-4821 - Updated AWS variable naming * AUTHN-4821 - Fixed issues with context updating * AUTHN-4821 - Fixed AWS example * AUTHN-4821 - Code cleanup, comments and moving a call * pre-commit fixes --------- Co-authored-by: ci.datadog-api-spec <packages@datadoghq.com>
1 parent 5575c65 commit 5416688

120 files changed

Lines changed: 11499 additions & 4551 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.generator/conftest.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,26 @@
11
# coding=utf-8
22
"""Define basic fixtures."""
33

4+
import hashlib
45
import json
56
import os
67
import pathlib
8+
import pytest
79
import re
810
import warnings
911
import zlib
1012
from collections import defaultdict
11-
12-
import pytest
1313
from dateutil.relativedelta import relativedelta
14-
from jinja2 import Environment, FileSystemLoader, Template
15-
from pytest_bdd import given, parsers, then, when
16-
import hashlib
17-
1814
from generator import openapi
15+
from generator.formatter import format_parameters, format_data_with_schema, go_name
1916
from generator.utils import (
2017
camel_case,
2118
given_variables,
2219
snake_case,
2320
untitle_case,
2421
)
25-
26-
from generator.formatter import format_parameters, format_data_with_schema, go_name
27-
22+
from jinja2 import Environment, FileSystemLoader, Template
23+
from pytest_bdd import given, parsers, then, when
2824

2925
MODIFIED_FEATURES = {pathlib.Path(p).resolve() for p in os.getenv("BDD_MODIFIED_FEATURES", "").split(" ") if p}
3026

@@ -74,6 +70,9 @@ def lookup(value, path):
7470
JINJA_ENV.globals["given_variables"] = given_variables
7571

7672
GO_EXAMPLE_J2 = JINJA_ENV.get_template("example.j2")
73+
DATADOG_EXAMPLES_J2 = {
74+
"aws.go": JINJA_ENV.get_template("example_aws.j2")
75+
}
7776

7877

7978
def pytest_bdd_after_scenario(request, feature, scenario):
@@ -113,6 +112,19 @@ def pytest_bdd_after_scenario(request, feature, scenario):
113112
with output.open("w") as f:
114113
f.write(data)
115114

115+
for file_name, template in DATADOG_EXAMPLES_J2.items():
116+
output = ROOT_PATH / "examples" / "datadog" / file_name
117+
output.parent.mkdir(parents=True, exist_ok=True)
118+
119+
data = template.render(
120+
context=context,
121+
version=version,
122+
scenario=scenario,
123+
operation_spec=operation_spec.spec,
124+
)
125+
with output.open("w") as f:
126+
f.write(data)
127+
116128

117129
def pytest_bdd_apply_tag(tag, function):
118130
"""Register tags as custom markers and skip test for '@skip' ones."""

.generator/src/generator/cli.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
1-
import pathlib
2-
31
import click
2+
import pathlib
43
from jinja2 import Environment, FileSystemLoader
54

6-
from . import openapi
75
from . import formatter
6+
from . import openapi
87
from . import utils
98

109
MODULE = "github.com/DataDog/datadog-api-client-go/v2"
@@ -67,8 +66,10 @@ def cli(specs, output):
6766
doc_j2 = env.get_template("doc.j2")
6867

6968
extra_files = {
69+
"aws.go": env.get_template("aws.j2"),
7070
"client.go": env.get_template("client.j2"),
7171
"configuration.go": env.get_template("configuration.j2"),
72+
"delegated_auth.go": env.get_template("delegated_auth.j2"),
7273
"utils.go": env.get_template("utils.j2"),
7374
"zstd.go": env.get_template("zstd.j2"),
7475
"no_zstd.go": env.get_template("no_zstd.j2"),

.generator/src/generator/templates/api.j2

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,23 @@ localVarQueryParams.Add("{{ parameter.name }}", {{ common_package_name }}.Parame
241241
{%- endif %}
242242

243243
{%- set authMethods = operation.security if "security" in operation else openapi.security %}
244+
{%- set appKeyNs = namespace(hasAppKeyAuth=false) %}
245+
{%- for authMethod in authMethods %}
246+
{%- for name in authMethod %}
247+
{%- if name == "appKeyAuth" %}
248+
{%- set appKeyNs.hasAppKeyAuth = true %}
249+
{%- endif %}
250+
{%- endfor %}
251+
{%- endfor %}
244252
{%- if authMethods %}
253+
{%- if appKeyNs.hasAppKeyAuth %}
254+
if a.Client.Cfg.DelegatedTokenConfig != nil {
255+
err = {{ common_package_name }}.UseDelegatedTokenAuth(ctx, &localVarHeaderParams, a.Client.Cfg.DelegatedTokenConfig)
256+
if err != nil {
257+
return {% if returnType %}localVarReturnValue, {% endif %}nil, err
258+
}
259+
} else {
260+
{%- endif %}
245261
{{ common_package_name }}.SetAuthKeys(
246262
ctx,
247263
&localVarHeaderParams,
@@ -254,6 +270,7 @@ localVarQueryParams.Add("{{ parameter.name }}", {{ common_package_name }}.Parame
254270
{%- endfor %}
255271
{%- endfor %}
256272
)
273+
{% if appKeyNs.hasAppKeyAuth %} } {% endif %}
257274
{%- endif %}
258275
req, err := a.Client.PrepareRequest(ctx, localVarPath, localVarHTTPMethod, localVarPostBody, localVarHeaderParams, localVarQueryParams, localVarFormParams, {% if formParameter %}&formFile{% else %}nil{% endif %})
259276
if err != nil {
Lines changed: 252 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,252 @@
1+
{% include "partial_header.j2" %}
2+
package {{ common_package_name }}
3+
4+
import (
5+
"context"
6+
"crypto/hmac"
7+
"crypto/sha256"
8+
"encoding/base64"
9+
"encoding/hex"
10+
"encoding/json"
11+
"fmt"
12+
"net/http"
13+
"os"
14+
"sort"
15+
"strings"
16+
"time"
17+
)
18+
19+
type Credentials struct {
20+
AccessKeyID string
21+
SecretAccessKey string
22+
SessionToken string
23+
}
24+
25+
// SigningData is the data structure that represents the Data used to generate and AWS Proof
26+
type SigningData struct {
27+
HeadersEncoded string `json:"iam_headers_encoded"`
28+
BodyEncoded string `json:"iam_body_encoded"`
29+
URLEncoded string `json:"iam_url_encoded"`
30+
Method string `json:"iam_method"`
31+
}
32+
33+
const (
34+
// Common Headers
35+
// orgIdHeader is the header we use to specify the name of the org we request a token for
36+
orgIdHeader = "x-ddog-org-id"
37+
hostHeader = "host"
38+
applicationForm = "application/x-www-form-urlencoded; charset=utf-8"
39+
40+
// AWS specific constants
41+
AWSAccessKeyIdName = "AWS_ACCESS_KEY_ID"
42+
AWSSecretAccessKeyName = "AWS_SECRET_ACCESS_KEY"
43+
AWSSessionTokenName = "AWS_SESSION_TOKEN"
44+
45+
amzDateHeader = "X-Amz-Date"
46+
amzTokenHeader = "X-Amz-Security-Token"
47+
amzDateFormat = "20060102"
48+
amzDateTimeFormat = "20060102T150405Z"
49+
defaultRegion = "us-east-1"
50+
defaultStsHost = "sts.amazonaws.com"
51+
regionalStsHost = "sts.%s.amazonaws.com"
52+
service = "sts"
53+
algorithm = "AWS4-HMAC-SHA256"
54+
aws4Request = "aws4_request"
55+
getCallerIdentityBody = "Action=GetCallerIdentity&Version=2011-06-15"
56+
)
57+
58+
const ProviderAWS = "aws"
59+
60+
type AWSAuth struct {
61+
AwsRegion string
62+
}
63+
64+
func (a *AWSAuth) Authenticate(ctx context.Context, config *DelegatedTokenConfig) (*DelegatedTokenCredentials, error) {
65+
// Get local AWS Credentials
66+
creds := a.GetCredentials(ctx)
67+
68+
if config == nil || config.OrgUUID == "" {
69+
return nil, fmt.Errorf("missing org UUID in config")
70+
}
71+
72+
// Use the credentials to generate the signing data
73+
data, err := a.GenerateAwsAuthData(config.OrgUUID, creds)
74+
if err != nil {
75+
return nil, err
76+
}
77+
78+
// Generate the auth string passed to the token endpoint
79+
authString := data.BodyEncoded + "|" + data.HeadersEncoded + "|" + data.Method + "|" + data.URLEncoded
80+
81+
authResponse, err := GetDelegatedToken(ctx, config.OrgUUID, authString)
82+
return authResponse, err
83+
}
84+
85+
func (a *AWSAuth) GetCredentials(ctx context.Context) *Credentials {
86+
keys := ctx.Value(ContextAWSVariables)
87+
if keys != nil {
88+
keysMap := keys.(map[string]string)
89+
creds := Credentials{}
90+
if accessKey, ok := keysMap[AWSAccessKeyIdName]; ok {
91+
creds.AccessKeyID = accessKey
92+
}
93+
if secretKey, ok := keysMap[AWSSecretAccessKeyName]; ok {
94+
creds.SecretAccessKey = secretKey
95+
}
96+
if sessionToken, ok := keysMap[AWSSessionTokenName]; ok {
97+
creds.SessionToken = sessionToken
98+
}
99+
return &creds
100+
} else {
101+
accessKey := os.Getenv(AWSAccessKeyIdName)
102+
secretKey := os.Getenv(AWSSecretAccessKeyName)
103+
sessionToken := os.Getenv(AWSSessionTokenName)
104+
return &Credentials{
105+
AccessKeyID: accessKey,
106+
SecretAccessKey: secretKey,
107+
SessionToken: sessionToken,
108+
}
109+
}
110+
}
111+
112+
func (a *AWSAuth) getConnectionParameters() (string, string, string) {
113+
region := a.AwsRegion
114+
var host string
115+
// Default to the default global STS Host (see here: https://docs.aws.amazon.com/general/latest/gr/sts.html)
116+
if region == "" {
117+
region = defaultRegion
118+
host = defaultStsHost
119+
} else {
120+
// If the region is not empty, use the regional STS host
121+
host = fmt.Sprintf(regionalStsHost, region)
122+
}
123+
stsFullURL := fmt.Sprintf("https://%s", host)
124+
return stsFullURL, region, host
125+
}
126+
127+
func (a *AWSAuth) GenerateAwsAuthData(orgUUID string, creds *Credentials) (*SigningData, error) {
128+
if orgUUID == "" {
129+
return nil, fmt.Errorf("missing org UUID")
130+
}
131+
if creds == nil || (creds.AccessKeyID == "" && creds.SecretAccessKey == "") || creds.SessionToken == "" {
132+
return nil, fmt.Errorf("missing AWS credentials")
133+
}
134+
stsFullURL, region, host := a.getConnectionParameters()
135+
136+
now := time.Now().UTC()
137+
138+
requestBody := getCallerIdentityBody
139+
h := sha256.Sum256([]byte(requestBody))
140+
payloadHash := hex.EncodeToString(h[:])
141+
142+
// Create the headers that factor into the signing algorithm
143+
headerMap := map[string][]string{
144+
contextLengthHeader: {
145+
fmt.Sprintf("%d", len(requestBody)),
146+
},
147+
contentTypeHeader: {
148+
applicationForm,
149+
},
150+
amzDateHeader: {
151+
now.Format(amzDateTimeFormat),
152+
},
153+
orgIdHeader: {
154+
orgUUID,
155+
},
156+
amzTokenHeader: {
157+
creds.SessionToken,
158+
},
159+
hostHeader: {
160+
host,
161+
},
162+
}
163+
164+
headerArr := make([]string, len(headerMap), len(headerMap))
165+
signedHeadersArr := make([]string, len(headerMap), len(headerMap))
166+
headerIdx := 0
167+
for k, v := range headerMap {
168+
loweredHeaderName := strings.ToLower(k)
169+
headerArr[headerIdx] = fmt.Sprintf("%s:%s", loweredHeaderName, strings.Join(v, ","))
170+
signedHeadersArr[headerIdx] = loweredHeaderName
171+
headerIdx++
172+
}
173+
sort.Strings(headerArr)
174+
sort.Strings(signedHeadersArr)
175+
signedHeaders := strings.Join(signedHeadersArr, ";")
176+
177+
canonicalRequest := strings.Join([]string{
178+
http.MethodPost,
179+
"/",
180+
"", // No query string
181+
strings.Join(headerArr, "\n") + "\n",
182+
signedHeaders,
183+
payloadHash,
184+
}, "\n")
185+
186+
// Create the string to sign
187+
hashCanonicalRequest := sha256.Sum256([]byte(canonicalRequest))
188+
credentialScope := strings.Join([]string{
189+
now.Format(amzDateFormat),
190+
region,
191+
service,
192+
aws4Request,
193+
}, "/")
194+
stringToSign := a.makeSignature(
195+
now,
196+
credentialScope,
197+
hex.EncodeToString(hashCanonicalRequest[:]),
198+
region,
199+
service,
200+
creds.SecretAccessKey,
201+
algorithm,
202+
)
203+
204+
// Create the authorization header
205+
credential := strings.Join([]string{
206+
creds.AccessKeyID,
207+
credentialScope,
208+
}, "/")
209+
authHeader := fmt.Sprintf("%s Credential=%s, SignedHeaders=%s, Signature=%s",
210+
algorithm, credential, signedHeaders, stringToSign)
211+
212+
headerMap["Authorization"] = []string{authHeader}
213+
headerMap["User-Agent"] = []string{GetUserAgent()}
214+
headersJSON, err := json.Marshal(headerMap)
215+
if err != nil {
216+
return nil, err
217+
}
218+
219+
return &SigningData{
220+
HeadersEncoded: base64.StdEncoding.EncodeToString(headersJSON),
221+
BodyEncoded: base64.StdEncoding.EncodeToString([]byte(requestBody)),
222+
Method: http.MethodPost,
223+
URLEncoded: base64.StdEncoding.EncodeToString([]byte(stsFullURL)),
224+
}, nil
225+
}
226+
227+
func (a *AWSAuth) makeSignature(t time.Time, credentialScope, payloadHash, region, service, secretAccessKey, algorithm string) string {
228+
// Create the string to sign
229+
stringToSign := strings.Join([]string{
230+
algorithm,
231+
t.Format(amzDateTimeFormat),
232+
credentialScope,
233+
payloadHash,
234+
}, "\n")
235+
236+
// Create the signing key
237+
kDate := hmac256(t.Format(amzDateFormat), []byte("AWS4"+secretAccessKey))
238+
kRegion := hmac256(region, kDate)
239+
kService := hmac256(service, kRegion)
240+
kSigning := hmac256(aws4Request, kService)
241+
242+
// Sign the string
243+
signature := hex.EncodeToString(hmac256(stringToSign, kSigning))
244+
245+
return signature
246+
}
247+
248+
func hmac256(data string, key []byte) []byte {
249+
h := hmac.New(sha256.New, key)
250+
h.Write([]byte(data))
251+
return h.Sum(nil)
252+
}

0 commit comments

Comments
 (0)