From 59df7bff5cafc747afab24ccce5a150cf22c2a13 Mon Sep 17 00:00:00 2001 From: Yann Hamon Date: Sun, 11 May 2025 03:57:09 +0200 Subject: [PATCH] Add tests for the HTTP loader Add another test case, remove accidental double memory caching --- pkg/loader/file.go | 17 +-- pkg/loader/http_test.go | 216 ++++++++++++++++++++++++++++++++++++++ pkg/registry/http_test.go | 175 ------------------------------ pkg/registry/registry.go | 4 +- 4 files changed, 219 insertions(+), 193 deletions(-) create mode 100644 pkg/loader/http_test.go delete mode 100644 pkg/registry/http_test.go diff --git a/pkg/loader/file.go b/pkg/loader/file.go index 076324e..9f76eba 100644 --- a/pkg/loader/file.go +++ b/pkg/loader/file.go @@ -24,11 +24,6 @@ func (l FileLoader) Load(url string) (any, error) { if err != nil { return nil, err } - if l.cache != nil { - if cached, err := l.cache.Get(path); err == nil { - return jsonschema.UnmarshalJSON(bytes.NewReader(cached.([]byte))) - } - } f, err := os.Open(path) if err != nil { @@ -45,12 +40,6 @@ func (l FileLoader) Load(url string) (any, error) { return nil, err } - if l.cache != nil { - if err = l.cache.Set(path, content); err != nil { - return nil, fmt.Errorf("failed to write cache to disk: %s", err) - } - } - return jsonschema.UnmarshalJSON(bytes.NewReader(content)) } @@ -71,8 +60,6 @@ func (l FileLoader) ToFile(url string) (string, error) { return path, nil } -func NewFileLoader(cache cache.Cache) *FileLoader { - return &FileLoader{ - cache: cache, - } +func NewFileLoader() *FileLoader { + return &FileLoader{} } diff --git a/pkg/loader/http_test.go b/pkg/loader/http_test.go new file mode 100644 index 0000000..0846bf8 --- /dev/null +++ b/pkg/loader/http_test.go @@ -0,0 +1,216 @@ +package loader + +import ( + "errors" + "fmt" + "net/http" + "net/http/httptest" + "sync" + "testing" +) + +type mockCache struct { + data map[string]any +} + +func (m *mockCache) Get(key string) (any, error) { + if val, ok := m.data[key]; ok { + return val, nil + } + return nil, errors.New("cache miss") +} + +func (m *mockCache) Set(key string, value any) error { + m.data[key] = value + return nil +} + +// Test basic functionality of HTTPURLLoader +func TestHTTPURLLoader_Load(t *testing.T) { + tests := []struct { + name string + mockResponse string + mockStatusCode int + cacheEnabled bool + expectError bool + expectCacheHit bool + }{ + { + name: "successful load", + mockResponse: `{"type": "object"}`, + mockStatusCode: http.StatusOK, + cacheEnabled: false, + expectError: false, + }, + { + name: "not found error", + mockResponse: "", + mockStatusCode: http.StatusNotFound, + cacheEnabled: false, + expectError: true, + }, + { + name: "server error", + mockResponse: "", + mockStatusCode: http.StatusInternalServerError, + cacheEnabled: false, + expectError: true, + }, + { + name: "cache hit", + mockResponse: `{"type": "object"}`, + mockStatusCode: http.StatusOK, + cacheEnabled: true, + expectError: false, + expectCacheHit: true, + }, + { + name: "Partial response from server", + mockResponse: `{"type": "objec`, + mockStatusCode: http.StatusOK, + cacheEnabled: false, + expectError: true, + expectCacheHit: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Mock HTTP server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(tt.mockStatusCode) + w.Write([]byte(tt.mockResponse)) + })) + defer server.Close() + + // Create HTTPURLLoader + loader := &HTTPURLLoader{ + client: *server.Client(), + cache: nil, + } + + if tt.cacheEnabled { + loader.cache = &mockCache{data: map[string]any{}} + if tt.expectCacheHit { + loader.cache.Set(server.URL, []byte(tt.mockResponse)) + } + } + + // Call Load and handle errors + res, err := loader.Load(server.URL) + if tt.expectError { + if err == nil { + t.Errorf("expected error, got nil") + } + } else { + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if res == nil { + t.Errorf("expected non-nil result, got nil") + } + } + + }) + } +} + +// Test basic functionality of HTTPURLLoader +func TestHTTPURLLoader_Load_Retries(t *testing.T) { + + tests := []struct { + name string + url string + expectError bool + expectCallCount int + consecutiveFailures int + }{ + { + name: "retries on 503", + url: "/503", + expectError: false, + expectCallCount: 2, + consecutiveFailures: 2, + }, + { + name: "fails when hitting max retries", + url: "/503", + expectError: true, + expectCallCount: 3, + consecutiveFailures: 5, + }, + { + name: "retry on connection reset", + url: "/simulate-reset", + expectError: false, + expectCallCount: 2, + consecutiveFailures: 1, + }, + { + name: "retry on connection reset", + url: "/simulate-reset", + expectError: true, + expectCallCount: 3, + consecutiveFailures: 5, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ccMutex := &sync.Mutex{} + callCounts := map[string]int{} + // Mock HTTP server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ccMutex.Lock() + callCounts[r.URL.Path]++ + callCount := callCounts[r.URL.Path] + ccMutex.Unlock() + + switch r.URL.Path { + case "/simulate-reset": + if callCount <= tt.consecutiveFailures { + if hj, ok := w.(http.Hijacker); ok { + conn, _, err := hj.Hijack() + if err != nil { + fmt.Printf("Hijacking failed: %v\n", err) + return + } + conn.Close() // Close the connection to simulate a reset + } + return + } + + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"type": "object"}`)) + + case "/503": + s := http.StatusServiceUnavailable + if callCount >= tt.consecutiveFailures { + s = http.StatusOK + } + w.WriteHeader(s) + w.Write([]byte(`{"type": "object"}`)) + } + })) + defer server.Close() + + // Create HTTPURLLoader + loader, _ := NewHTTPURLLoader(false, nil) + + fullurl := server.URL + tt.url + // Call Load and handle errors + _, err := loader.Load(fullurl) + if tt.expectError && err == nil { + t.Error("expected error, got none") + } + if !tt.expectError && err != nil { + t.Errorf("expected no error, got %v", err) + } + ccMutex.Lock() + if callCounts[tt.url] != tt.expectCallCount { + t.Errorf("expected %d calls, got: %d", tt.expectCallCount, callCounts[tt.url]) + } + ccMutex.Unlock() + }) + } +} diff --git a/pkg/registry/http_test.go b/pkg/registry/http_test.go deleted file mode 100644 index d82d2fd..0000000 --- a/pkg/registry/http_test.go +++ /dev/null @@ -1,175 +0,0 @@ -package registry - -import ( - "net/http" -) - -type mockHTTPGetter struct { - callNumber int - httpGet func(mockHTTPGetter, string) (*http.Response, error) -} - -func newMockHTTPGetter(f func(mockHTTPGetter, string) (*http.Response, error)) *mockHTTPGetter { - return &mockHTTPGetter{ - callNumber: 0, - httpGet: f, - } -} -func (m *mockHTTPGetter) Get(url string) (resp *http.Response, err error) { - m.callNumber = m.callNumber + 1 - return m.httpGet(*m, url) -} - -//func TestDownloadSchema(t *testing.T) { -// callCounts := map[string]int{} -// -// // http server to simulate different responses -// http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { -// var s int -// callCounts[r.URL.Path]++ -// callCount := callCounts[r.URL.Path] -// -// switch r.URL.Path { -// case "/404": -// s = http.StatusNotFound -// case "/500": -// s = http.StatusInternalServerError -// case "/503": -// if callCount < 2 { -// s = http.StatusServiceUnavailable -// } else { -// s = http.StatusOK // Should succeed on 3rd try -// } -// -// case "/simulate-reset": -// if callCount < 2 { -// if hj, ok := w.(http.Hijacker); ok { -// conn, _, err := hj.Hijack() -// if err != nil { -// fmt.Printf("Hijacking failed: %v\n", err) -// return -// } -// conn.Close() // Close the connection to simulate a reset -// } -// return -// } -// s = http.StatusOK // Should succeed on third try -// -// default: -// s = http.StatusOK -// } -// -// w.WriteHeader(s) -// w.Write([]byte(http.StatusText(s))) -// }) -// -// port := fmt.Sprint(rand.Intn(1000) + 9000) // random port -// server := &http.Server{Addr: "127.0.0.1:" + port} -// url := fmt.Sprintf("http://localhost:%s", port) -// -// go func() { -// if err := server.ListenAndServe(); err != nil { -// fmt.Printf("Failed to start server: %v\n", err) -// } -// }() -// defer server.Shutdown(nil) -// -// // Wait for the server to start -// for i := 0; i < 20; i++ { -// if _, err := http.Get(url); err == nil { -// break -// } -// -// if i == 19 { -// t.Error("http server did not start") -// return -// } -// -// time.Sleep(50 * time.Millisecond) -// } -// -// for _, testCase := range []struct { -// name string -// schemaPathTemplate string -// strict bool -// resourceKind, resourceAPIVersion, k8sversion string -// expect []byte -// expectErr error -// }{ -// { -// "retry connection reset by peer", -// fmt.Sprintf("%s/simulate-reset", url), -// true, -// "Deployment", -// "v1", -// "1.18.0", -// []byte(http.StatusText(http.StatusOK)), -// nil, -// }, -// { -// "getting 404", -// fmt.Sprintf("%s/404", url), -// true, -// "Deployment", -// "v1", -// "1.18.0", -// nil, -// fmt.Errorf("could not find schema at %s/404", url), -// }, -// { -// "getting 500", -// fmt.Sprintf("%s/500", url), -// true, -// "Deployment", -// "v1", -// "1.18.0", -// nil, -// fmt.Errorf("failed downloading schema at %s/500: Get \"%s/500\": GET %s/500 giving up after 3 attempt(s)", url, url, url), -// }, -// { -// "retry 503", -// fmt.Sprintf("%s/503", url), -// true, -// "Deployment", -// "v1", -// "1.18.0", -// []byte(http.StatusText(http.StatusOK)), -// nil, -// }, -// { -// "200", -// url, -// true, -// "Deployment", -// "v1", -// "1.18.0", -// []byte(http.StatusText(http.StatusOK)), -// nil, -// }, -// } { -// callCounts = map[string]int{} // Reinitialise counters -// -// reg, err := newHTTPRegistry(testCase.schemaPathTemplate, "", testCase.strict, true, true) -// if err != nil { -// t.Errorf("during test '%s': failed to create registry: %s", testCase.name, err) -// continue -// } -// -// _, res, err := reg.DownloadSchema(testCase.resourceKind, testCase.resourceAPIVersion, testCase.k8sversion) -// if err == nil || testCase.expectErr == nil { -// if err == nil && testCase.expectErr != nil { -// t.Errorf("during test '%s': expected error\n%s, got nil", testCase.name, testCase.expectErr) -// } -// if err != nil && testCase.expectErr == nil { -// t.Errorf("during test '%s': expected no error, got\n%s\n", testCase.name, err) -// } -// } else if err.Error() != testCase.expectErr.Error() { -// t.Errorf("during test '%s': expected error\n%s, got:\n%s\n", testCase.name, testCase.expectErr, err) -// } -// -// if !bytes.Equal(res, testCase.expect) { -// t.Errorf("during test '%s': expected '%s', got '%s'", testCase.name, testCase.expect, res) -// } -// } -// -//} diff --git a/pkg/registry/registry.go b/pkg/registry/registry.go index 9584c17..dd965c3 100644 --- a/pkg/registry/registry.go +++ b/pkg/registry/registry.go @@ -91,8 +91,6 @@ func New(schemaLocation string, cacheFolder string, strict bool, skipTLS bool, d } c = cache.NewOnDiskCache(cacheFolder) - } else { - c = cache.NewInMemoryCache() } if strings.HasPrefix(schemaLocation, "http") { @@ -103,6 +101,6 @@ func New(schemaLocation string, cacheFolder string, strict bool, skipTLS bool, d return newHTTPRegistry(schemaLocation, httpLoader, strict, debug) } - fileLoader := loader.NewFileLoader(c) + fileLoader := loader.NewFileLoader() return newLocalRegistry(schemaLocation, fileLoader, strict, debug) }