|
| 1 | +package checks |
| 2 | + |
| 3 | +import ( |
| 4 | + "fmt" |
| 5 | + "math" |
| 6 | + "reflect" |
| 7 | + "strconv" |
| 8 | + "strings" |
| 9 | + |
| 10 | + api "github.com/bootdotdev/bootdev/client" |
| 11 | +) |
| 12 | + |
| 13 | +func LocalSubmissionEvent(cliData api.CLIData, results []api.CLIStepResult) api.LessonSubmissionEvent { |
| 14 | + failure := EvaluateCLIResults(cliData, results) |
| 15 | + slug := api.VerificationResultSlugSuccess |
| 16 | + if failure != nil { |
| 17 | + slug = api.VerificationResultSlugFailure |
| 18 | + if failure.FailedStepIndex >= 0 && |
| 19 | + failure.FailedStepIndex < len(cliData.Steps) && |
| 20 | + cliData.Steps[failure.FailedStepIndex].NoPenaltyOnFail { |
| 21 | + slug = api.VerificationResultSlugNoop |
| 22 | + } |
| 23 | + } |
| 24 | + |
| 25 | + return api.LessonSubmissionEvent{ |
| 26 | + ResultSlug: slug, |
| 27 | + StructuredErrCLI: failure, |
| 28 | + XPReward: -1, |
| 29 | + } |
| 30 | +} |
| 31 | + |
| 32 | +func EvaluateCLIResults(cliData api.CLIData, results []api.CLIStepResult) *api.StructuredErrCLI { |
| 33 | + for stepIndex, step := range cliData.Steps { |
| 34 | + if stepIndex >= len(results) { |
| 35 | + return localFailure(stepIndex, 0, "missing result for step") |
| 36 | + } |
| 37 | + |
| 38 | + switch { |
| 39 | + case step.CLICommand != nil: |
| 40 | + result := results[stepIndex].CLICommandResult |
| 41 | + if result == nil { |
| 42 | + return localFailure(stepIndex, 0, "missing CLI command result") |
| 43 | + } |
| 44 | + if failure := evaluateCLICommandTests(stepIndex, *step.CLICommand, *result); failure != nil { |
| 45 | + return failure |
| 46 | + } |
| 47 | + case step.HTTPRequest != nil: |
| 48 | + result := results[stepIndex].HTTPRequestResult |
| 49 | + if result == nil { |
| 50 | + return localFailure(stepIndex, 0, "missing HTTP request result") |
| 51 | + } |
| 52 | + if failure := evaluateHTTPRequestTests(stepIndex, *step.HTTPRequest, *result); failure != nil { |
| 53 | + return failure |
| 54 | + } |
| 55 | + default: |
| 56 | + return localFailure(stepIndex, 0, "missing step definition") |
| 57 | + } |
| 58 | + } |
| 59 | + |
| 60 | + return nil |
| 61 | +} |
| 62 | + |
| 63 | +func evaluateCLICommandTests(stepIndex int, cmd api.CLIStepCLICommand, result api.CLICommandResult) *api.StructuredErrCLI { |
| 64 | + for testIndex, test := range cmd.Tests { |
| 65 | + var err error |
| 66 | + |
| 67 | + switch { |
| 68 | + case test.ExitCode != nil: |
| 69 | + if result.ExitCode != *test.ExitCode { |
| 70 | + err = fmt.Errorf("expected exit code %d, got %d", *test.ExitCode, result.ExitCode) |
| 71 | + } |
| 72 | + case len(test.StdoutContainsAll) > 0: |
| 73 | + for _, contains := range test.StdoutContainsAll { |
| 74 | + needle := InterpolateVariables(contains, result.Variables) |
| 75 | + if !strings.Contains(result.Stdout, needle) { |
| 76 | + err = fmt.Errorf("expected stdout to contain %q", needle) |
| 77 | + break |
| 78 | + } |
| 79 | + } |
| 80 | + case len(test.StdoutContainsNone) > 0: |
| 81 | + for _, containsNone := range test.StdoutContainsNone { |
| 82 | + needle := InterpolateVariables(containsNone, result.Variables) |
| 83 | + if strings.Contains(result.Stdout, needle) { |
| 84 | + err = fmt.Errorf("expected stdout to not contain %q", needle) |
| 85 | + break |
| 86 | + } |
| 87 | + } |
| 88 | + case test.StdoutLinesGt != nil: |
| 89 | + lineCount := stdoutLineCount(result.Stdout) |
| 90 | + if lineCount <= *test.StdoutLinesGt { |
| 91 | + err = fmt.Errorf("expected stdout to have more than %d lines, got %d", *test.StdoutLinesGt, lineCount) |
| 92 | + } |
| 93 | + case test.StdoutJq != nil: |
| 94 | + err = evaluateStdoutJq(result.Stdout, *test.StdoutJq, result.Variables) |
| 95 | + default: |
| 96 | + err = fmt.Errorf("unsupported CLI command test") |
| 97 | + } |
| 98 | + |
| 99 | + if err != nil { |
| 100 | + return localFailure(stepIndex, testIndex, err.Error()) |
| 101 | + } |
| 102 | + } |
| 103 | + |
| 104 | + return nil |
| 105 | +} |
| 106 | + |
| 107 | +func evaluateHTTPRequestTests(stepIndex int, req api.CLIStepHTTPRequest, result api.HTTPRequestResult) *api.StructuredErrCLI { |
| 108 | + if result.Err != "" { |
| 109 | + return localFailure(stepIndex, 0, result.Err) |
| 110 | + } |
| 111 | + |
| 112 | + for testIndex, test := range req.Tests { |
| 113 | + var err error |
| 114 | + |
| 115 | + switch { |
| 116 | + case test.StatusCode != nil: |
| 117 | + if result.StatusCode != *test.StatusCode { |
| 118 | + err = fmt.Errorf("expected status code %d, got %d", *test.StatusCode, result.StatusCode) |
| 119 | + } |
| 120 | + case test.BodyContains != nil: |
| 121 | + needle := InterpolateVariables(*test.BodyContains, result.Variables) |
| 122 | + if !strings.Contains(result.BodyString, needle) { |
| 123 | + err = fmt.Errorf("expected body to contain %q", needle) |
| 124 | + } |
| 125 | + case test.BodyContainsNone != nil: |
| 126 | + needle := InterpolateVariables(*test.BodyContainsNone, result.Variables) |
| 127 | + if strings.Contains(result.BodyString, needle) { |
| 128 | + err = fmt.Errorf("expected body to not contain %q", needle) |
| 129 | + } |
| 130 | + case test.HeadersContain != nil: |
| 131 | + err = evaluateHeaderContains(result.ResponseHeaders, *test.HeadersContain, result.Variables, "header") |
| 132 | + case test.TrailersContain != nil: |
| 133 | + err = evaluateHeaderContains(result.ResponseTrailers, *test.TrailersContain, result.Variables, "trailer") |
| 134 | + case test.JSONValue != nil: |
| 135 | + err = evaluateHTTPJSONValue(result.BodyString, *test.JSONValue, result.Variables) |
| 136 | + default: |
| 137 | + err = fmt.Errorf("unsupported HTTP request test") |
| 138 | + } |
| 139 | + |
| 140 | + if err != nil { |
| 141 | + return localFailure(stepIndex, testIndex, err.Error()) |
| 142 | + } |
| 143 | + } |
| 144 | + |
| 145 | + return nil |
| 146 | +} |
| 147 | + |
| 148 | +func evaluateHeaderContains(headers map[string]string, test api.HTTPRequestTestHeader, variables map[string]string, label string) error { |
| 149 | + key := InterpolateVariables(test.Key, variables) |
| 150 | + want := InterpolateVariables(test.Value, variables) |
| 151 | + |
| 152 | + got, ok := findHeaderValue(headers, key) |
| 153 | + if !ok { |
| 154 | + return fmt.Errorf("expected %s %q to exist", label, key) |
| 155 | + } |
| 156 | + if !strings.Contains(got, want) { |
| 157 | + return fmt.Errorf("expected %s %q to contain %q, got %q", label, key, want, got) |
| 158 | + } |
| 159 | + |
| 160 | + return nil |
| 161 | +} |
| 162 | + |
| 163 | +func evaluateHTTPJSONValue(body string, test api.HTTPRequestTestJSONValue, variables map[string]string) error { |
| 164 | + got, err := valFromJqPath(test.Path, body) |
| 165 | + if err != nil { |
| 166 | + return err |
| 167 | + } |
| 168 | + |
| 169 | + want, err := httpJSONExpectedValue(test, variables) |
| 170 | + if err != nil { |
| 171 | + return err |
| 172 | + } |
| 173 | + |
| 174 | + if !compareValues(got, test.Operator, want) { |
| 175 | + return fmt.Errorf("expected JSON at %s %s %v, got %v", test.Path, test.Operator, want, got) |
| 176 | + } |
| 177 | + |
| 178 | + return nil |
| 179 | +} |
| 180 | + |
| 181 | +func httpJSONExpectedValue(test api.HTTPRequestTestJSONValue, variables map[string]string) (any, error) { |
| 182 | + switch { |
| 183 | + case test.IntValue != nil: |
| 184 | + return *test.IntValue, nil |
| 185 | + case test.StringValue != nil: |
| 186 | + return InterpolateVariables(*test.StringValue, variables), nil |
| 187 | + case test.BoolValue != nil: |
| 188 | + return *test.BoolValue, nil |
| 189 | + default: |
| 190 | + return nil, fmt.Errorf("missing expected JSON value") |
| 191 | + } |
| 192 | +} |
| 193 | + |
| 194 | +func evaluateStdoutJq(stdout string, test api.StdoutJqTest, variables map[string]string) error { |
| 195 | + queryText := InterpolateVariables(test.Query, variables) |
| 196 | + |
| 197 | + input, err := parseJqInput(stdout, test.InputMode) |
| 198 | + if err != nil { |
| 199 | + return err |
| 200 | + } |
| 201 | + |
| 202 | + results, err := executeJqQuery(queryText, input) |
| 203 | + if err != nil { |
| 204 | + return err |
| 205 | + } |
| 206 | + if len(results) != len(test.ExpectedResults) { |
| 207 | + return fmt.Errorf("expected jq query %q to return %d result(s), got %d", queryText, len(test.ExpectedResults), len(results)) |
| 208 | + } |
| 209 | + |
| 210 | + for i, expected := range test.ExpectedResults { |
| 211 | + want, err := jqExpectedValue(expected, variables) |
| 212 | + if err != nil { |
| 213 | + return err |
| 214 | + } |
| 215 | + if !compareValues(results[i], api.OperatorType(expected.Operator), want) { |
| 216 | + return fmt.Errorf("expected jq result %d to be %s %v, got %v", i+1, expected.Operator, want, results[i]) |
| 217 | + } |
| 218 | + } |
| 219 | + |
| 220 | + return nil |
| 221 | +} |
| 222 | + |
| 223 | +func jqExpectedValue(expected api.JqExpectedResult, variables map[string]string) (any, error) { |
| 224 | + switch expected.Type { |
| 225 | + case api.JqTypeString: |
| 226 | + if str, ok := expected.Value.(string); ok { |
| 227 | + return InterpolateVariables(str, variables), nil |
| 228 | + } |
| 229 | + return expected.Value, nil |
| 230 | + case api.JqTypeInt: |
| 231 | + if str, ok := expected.Value.(string); ok { |
| 232 | + parsed, err := strconv.Atoi(InterpolateVariables(str, variables)) |
| 233 | + if err != nil { |
| 234 | + return nil, err |
| 235 | + } |
| 236 | + return parsed, nil |
| 237 | + } |
| 238 | + return expected.Value, nil |
| 239 | + case api.JqTypeBool: |
| 240 | + if str, ok := expected.Value.(string); ok { |
| 241 | + parsed, err := strconv.ParseBool(InterpolateVariables(str, variables)) |
| 242 | + if err != nil { |
| 243 | + return nil, err |
| 244 | + } |
| 245 | + return parsed, nil |
| 246 | + } |
| 247 | + return expected.Value, nil |
| 248 | + default: |
| 249 | + return nil, fmt.Errorf("unsupported jq expected result type %q", expected.Type) |
| 250 | + } |
| 251 | +} |
| 252 | + |
| 253 | +func compareValues(got any, operator api.OperatorType, want any) bool { |
| 254 | + switch operator { |
| 255 | + case api.OpEquals, "==": |
| 256 | + return valuesEqual(got, want) |
| 257 | + case api.OpGreaterThan, ">": |
| 258 | + gotNum, gotOK := numberValue(got) |
| 259 | + wantNum, wantOK := numberValue(want) |
| 260 | + return gotOK && wantOK && gotNum > wantNum |
| 261 | + case api.OpContains: |
| 262 | + return strings.Contains(fmt.Sprintf("%v", got), fmt.Sprintf("%v", want)) |
| 263 | + case api.OpNotContains: |
| 264 | + return !strings.Contains(fmt.Sprintf("%v", got), fmt.Sprintf("%v", want)) |
| 265 | + default: |
| 266 | + return false |
| 267 | + } |
| 268 | +} |
| 269 | + |
| 270 | +func valuesEqual(got any, want any) bool { |
| 271 | + if gotNum, gotOK := numberValue(got); gotOK { |
| 272 | + wantNum, wantOK := numberValue(want) |
| 273 | + return wantOK && math.Abs(gotNum-wantNum) < 0.000000001 |
| 274 | + } |
| 275 | + return reflect.DeepEqual(got, want) |
| 276 | +} |
| 277 | + |
| 278 | +func numberValue(value any) (float64, bool) { |
| 279 | + switch v := value.(type) { |
| 280 | + case int: |
| 281 | + return float64(v), true |
| 282 | + case int64: |
| 283 | + return float64(v), true |
| 284 | + case float64: |
| 285 | + return v, true |
| 286 | + case jsonNumber: |
| 287 | + parsed, err := strconv.ParseFloat(v.String(), 64) |
| 288 | + return parsed, err == nil |
| 289 | + default: |
| 290 | + return 0, false |
| 291 | + } |
| 292 | +} |
| 293 | + |
| 294 | +func stdoutLineCount(stdout string) int { |
| 295 | + if stdout == "" { |
| 296 | + return 0 |
| 297 | + } |
| 298 | + return strings.Count(stdout, "\n") + 1 |
| 299 | +} |
| 300 | + |
| 301 | +func localFailure(stepIndex int, testIndex int, message string) *api.StructuredErrCLI { |
| 302 | + return &api.StructuredErrCLI{ |
| 303 | + ErrorMessage: message, |
| 304 | + FailedStepIndex: stepIndex, |
| 305 | + FailedTestIndex: testIndex, |
| 306 | + } |
| 307 | +} |
| 308 | + |
| 309 | +type jsonNumber interface { |
| 310 | + String() string |
| 311 | +} |
0 commit comments