Skip to content

Commit 69e2e45

Browse files
Introduce tests for the internal package
1 parent cb7be38 commit 69e2e45

File tree

1 file changed

+70
-0
lines changed

1 file changed

+70
-0
lines changed

internal/internal_test.go

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
// Copyright 2024 Google LLC
2+
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
7+
// https://www.apache.org/licenses/LICENSE-2.0
8+
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package internal_test
16+
17+
import (
18+
"fmt"
19+
"net/http"
20+
"net/http/httptest"
21+
"testing"
22+
23+
"github.com/google-gemini/proxy-to-gemini/internal"
24+
)
25+
26+
func TestErrorHandler(t *testing.T) {
27+
tests := []struct {
28+
name string
29+
method string
30+
code int
31+
msg string
32+
arg []interface{}
33+
wantBody string
34+
wantLog string
35+
}{
36+
{
37+
name: "Bad request without args",
38+
method: http.MethodPost,
39+
code: http.StatusBadRequest,
40+
msg: "failed to read request body",
41+
wantBody: "failed to read request body\n",
42+
},
43+
{
44+
name: "Internal server error with args",
45+
method: http.MethodGet,
46+
code: http.StatusInternalServerError,
47+
msg: "failed to generate content: %v",
48+
arg: []interface{}{fmt.Errorf("generic error")},
49+
wantBody: "failed to generate content: generic error\n",
50+
},
51+
}
52+
53+
for _, tt := range tests {
54+
t.Run(tt.name, func(t *testing.T) {
55+
recorder := httptest.NewRecorder()
56+
57+
req := httptest.NewRequest(tt.method, "/", nil)
58+
59+
internal.ErrorHandler(recorder, req, tt.code, tt.msg, tt.arg...)
60+
61+
if recorder.Code != tt.code {
62+
t.Errorf("got status %v, want %v", recorder.Code, tt.code)
63+
}
64+
65+
if recorder.Body.String() != tt.wantBody {
66+
t.Errorf("got body %v, want %v", recorder.Body.String(), tt.wantBody)
67+
}
68+
})
69+
}
70+
}

0 commit comments

Comments
 (0)