Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,17 @@ func RequireBearerToken(verifier TokenVerifier, opts *RequireBearerTokenOptions)
tokenInfo, errmsg, code := verify(r, verifier, opts)
if code != 0 {
if code == http.StatusUnauthorized || code == http.StatusForbidden {
if opts != nil && opts.ResourceMetadataURL != "" {
w.Header().Add("WWW-Authenticate", "Bearer resource_metadata="+opts.ResourceMetadataURL)
if opts != nil {
var params []string
if opts.ResourceMetadataURL != "" {
params = append(params, "resource_metadata=\""+opts.ResourceMetadataURL+"\"")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is clearer: fmt.Sprintf("resource_metadata=%q", opts....)

and ditto below

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, for some reason I mixed it up and thought %q produces single quotes. Will adjust the code.

}
if len(opts.Scopes) > 0 {
params = append(params, "scope=\""+strings.Join(opts.Scopes, " ")+"\"")
}
if len(params) > 0 {
w.Header().Add("WWW-Authenticate", "Bearer "+strings.Join(params, ", "))
}
}
}
http.Error(w, errmsg, code)
Expand Down
96 changes: 96 additions & 0 deletions auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,3 +188,99 @@ func TestProtectedResourceMetadataHandler(t *testing.T) {
})
}
}

func TestRequireBearerToken(t *testing.T) {
verifier := func(_ context.Context, token string, _ *http.Request) (*TokenInfo, error) {
if token == "valid" {
return &TokenInfo{Expiration: time.Now().Add(time.Hour), Scopes: []string{"read"}}, nil
}
return nil, ErrInvalidToken
}

tests := []struct {
name string
opts *RequireBearerTokenOptions
authHeader string
wantHeader string
wantStatus int
}{
{
name: "no middleware options",
opts: nil,
authHeader: "Bearer invalid",
wantHeader: "",
wantStatus: http.StatusUnauthorized,
},
{
name: "metadata only",
opts: &RequireBearerTokenOptions{
ResourceMetadataURL: "https://example.com/resource-metadata",
},
authHeader: "Bearer invalid",
wantHeader: "Bearer resource_metadata=\"https://example.com/resource-metadata\"",
wantStatus: http.StatusUnauthorized,
},
{
name: "scopes only",
opts: &RequireBearerTokenOptions{
Scopes: []string{"read", "write"},
},
authHeader: "Bearer invalid",
wantHeader: "Bearer scope=\"read write\"",
wantStatus: http.StatusUnauthorized,
},
{
name: "metadata and scopes",
opts: &RequireBearerTokenOptions{
ResourceMetadataURL: "https://example.com/resource-metadata",
Scopes: []string{"read", "write"},
},
authHeader: "Bearer invalid",
wantHeader: "Bearer resource_metadata=\"https://example.com/resource-metadata\", scope=\"read write\"",
wantStatus: http.StatusUnauthorized,
},
{
name: "insufficient scope",
opts: &RequireBearerTokenOptions{
Scopes: []string{"admin"},
},
authHeader: "Bearer valid", // Has "read", needs "admin" -> 403
wantHeader: "Bearer scope=\"admin\"",
wantStatus: http.StatusForbidden,
},
{
name: "success",
opts: &RequireBearerTokenOptions{
Scopes: []string{"read"},
},
authHeader: "Bearer valid",
wantHeader: "",
wantStatus: http.StatusOK,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
handler := RequireBearerToken(verifier, tt.opts)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))

req := httptest.NewRequest("GET", "/", nil)
if tt.authHeader != "" {
req.Header.Set("Authorization", tt.authHeader)
}
rec := httptest.NewRecorder()

handler.ServeHTTP(rec, req)

if rec.Code != tt.wantStatus {
t.Errorf("status = %d, want %d", rec.Code, tt.wantStatus)
}

got := rec.Header().Get("WWW-Authenticate")
if got != tt.wantHeader {
t.Errorf("WWW-Authenticate = %q, want %q", got, tt.wantHeader)
}
})
}
}
Loading