Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 57 additions & 8 deletions observability/metrics/prometheus/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package prometheus

import (
"context"
"crypto/tls"
"fmt"
"net"
"net/http"
Expand All @@ -36,12 +37,16 @@ const (
defaultPrometheusHost = "" // IPv4 and IPv6
prometheusPortEnvName = "METRICS_PROMETHEUS_PORT"
prometheusHostEnvName = "METRICS_PROMETHEUS_HOST"
prometheusTLSCertEnvName = "METRICS_TLS_CERT"
prometheusTLSKeyEnvName = "METRICS_TLS_KEY"
)

type ServerOption func(*options)

type Server struct {
http *http.Server
http *http.Server
certFile string
keyFile string
}

func NewServer(opts ...ServerOption) (*Server, error) {
Expand All @@ -56,6 +61,8 @@ func NewServer(opts ...ServerOption) (*Server, error) {

envOverride(&o.host, prometheusHostEnvName)
envOverride(&o.port, prometheusPortEnvName)
envOverride(&o.certFile, prometheusTLSCertEnvName)
envOverride(&o.keyFile, prometheusTLSKeyEnvName)

if err := validate(&o); err != nil {
return nil, err
Expand All @@ -68,25 +75,33 @@ func NewServer(opts ...ServerOption) (*Server, error) {

return &Server{
http: &http.Server{
Addr: addr,
Handler: mux,
// https://medium.com/a-journey-with-go/go-understand-and-mitigate-slowloris-attack-711c1b1403f6
Addr: addr,
Handler: mux,
TLSConfig: o.tlsConfig,
ReadHeaderTimeout: 5 * time.Second,
},
certFile: o.certFile,
keyFile: o.keyFile,
}, nil
}

func (s *Server) ListenAndServe() {
s.http.ListenAndServe()
func (s *Server) ListenAndServe() error {
if s.http.TLSConfig != nil || (s.certFile != "" && s.keyFile != "") {
return s.http.ListenAndServeTLS(s.certFile, s.keyFile)
}
return s.http.ListenAndServe()
}

func (s *Server) Shutdown(ctx context.Context) error {
return s.http.Shutdown(ctx)
}

type options struct {
host string
port string
host string
port string
tlsConfig *tls.Config
certFile string
keyFile string
}

func WithHost(host string) ServerOption {
Expand All @@ -101,6 +116,22 @@ func WithPort(port string) ServerOption {
}
}

// WithTLSConfig configures the server to use the provided TLS configuration.
// This allows programmatic control over TLS settings like MinVersion, CipherSuites, etc.
func WithTLSConfig(cfg *tls.Config) ServerOption {
return func(o *options) {
o.tlsConfig = cfg
}
}

// WithTLSCertFiles configures the server to use TLS with the provided certificate and key files.
func WithTLSCertFiles(certFile, keyFile string) ServerOption {
return func(o *options) {
o.certFile = certFile
o.keyFile = keyFile
}
}

func validate(o *options) error {
port, err := strconv.ParseUint(o.port, 10, 16)
if err != nil {
Expand All @@ -122,3 +153,21 @@ func envOverride(target *string, envName string) {
*target = val
}
}

type tlsConfigKey struct{}

// ContextWithTLSConfig adds a TLS configuration to the context.
// This allows programmatic configuration of TLS settings like MinVersion, CipherSuites, ClientAuth, etc.
// when using frameworks like sharedmain where direct NewServer() options aren't accessible.
func ContextWithTLSConfig(ctx context.Context, cfg *tls.Config) context.Context {
return context.WithValue(ctx, tlsConfigKey{}, cfg)
}

// TLSConfigFromContext retrieves the TLS configuration from the context.
// Returns nil if no TLS configuration was set.
func TLSConfigFromContext(ctx context.Context) *tls.Config {
if cfg, ok := ctx.Value(tlsConfigKey{}).(*tls.Config); ok {
return cfg
}
return nil
}
142 changes: 142 additions & 0 deletions observability/metrics/prometheus/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
package prometheus

import (
"context"
"crypto/tls"
"testing"

"github.com/google/go-cmp/cmp"
Expand Down Expand Up @@ -72,3 +74,143 @@
t.Error("expected above port range to fail")
}
}

func TestNewServerWithTLSConfig(t *testing.T) {
tlsConfig := &tls.Config{
MinVersion: tls.VersionTLS13,
MaxVersion: tls.VersionTLS13,
}

s, err := NewServer(WithTLSConfig(tlsConfig))
if err != nil {
t.Fatal("NewServer() =", err)
}

if s.http.TLSConfig == nil {
t.Error("expected TLSConfig to be set on http.Server")
}

if s.http.TLSConfig.MinVersion != tls.VersionTLS13 {
t.Errorf("expected MinVersion to be TLS 1.3, got %v", s.http.TLSConfig.MinVersion)
}
}

func TestNewServerWithTLSCertFiles(t *testing.T) {
s, err := NewServer(WithTLSCertFiles("/path/to/cert.pem", "/path/to/key.pem"))
if err != nil {
t.Fatal("NewServer() =", err)
}

if s.certFile != "/path/to/cert.pem" {
t.Errorf("expected certFile to be /path/to/cert.pem, got %s", s.certFile)
}

if s.keyFile != "/path/to/key.pem" {
t.Errorf("expected keyFile to be /path/to/key.pem, got %s", s.keyFile)
}
}

func TestNewServerWithTLSEnvVars(t *testing.T) {
t.Setenv(prometheusTLSCertEnvName, "/etc/tls/tls.crt")
t.Setenv(prometheusTLSKeyEnvName, "/etc/tls/tls.key")

s, err := NewServer()
if err != nil {
t.Fatal("NewServer() =", err)
}

if s.certFile != "/etc/tls/tls.crt" {
t.Errorf("expected certFile to be /etc/tls/tls.crt, got %s", s.certFile)
}

if s.keyFile != "/etc/tls/tls.key" {
t.Errorf("expected keyFile to be /etc/tls/tls.key, got %s", s.keyFile)
}
}

func TestNewServerTLSEnvOverridesOption(t *testing.T) {
t.Setenv(prometheusTLSCertEnvName, "/env/cert.pem")
t.Setenv(prometheusTLSKeyEnvName, "/env/key.pem")

s, err := NewServer(WithTLSCertFiles("/opt/cert.pem", "/opt/key.pem"))
if err != nil {
t.Fatal("NewServer() =", err)
}

if s.certFile != "/env/cert.pem" {
t.Errorf("expected env var to override option, got certFile=%s", s.certFile)
}

if s.keyFile != "/env/key.pem" {
t.Errorf("expected env var to override option, got keyFile=%s", s.keyFile)
}
}

func TestContextWithTLSConfig(t *testing.T) {
ctx := context.Background()

if cfg := TLSConfigFromContext(ctx); cfg != nil {
t.Error("expected TLSConfigFromContext to return nil for empty context")
}

tlsConfig := &tls.Config{
MinVersion: tls.VersionTLS12,
}

ctx = ContextWithTLSConfig(ctx, tlsConfig)

retrieved := TLSConfigFromContext(ctx)
if retrieved == nil {
t.Fatal("expected TLSConfigFromContext to return the config")
}

if retrieved.MinVersion != tls.VersionTLS12 {
t.Errorf("expected MinVersion to be TLS 1.2, got %v", retrieved.MinVersion)
}
}

func TestTLSConfigWithCertFiles(t *testing.T) {
tlsConfig := &tls.Config{
MinVersion: tls.VersionTLS13,
MaxVersion: tls.VersionTLS13,
CipherSuites: []uint16{
tls.TLS_AES_256_GCM_SHA384,
tls.TLS_CHACHA20_POLY1305_SHA256,
},
}

s, err := NewServer(
WithTLSConfig(tlsConfig),
WithTLSCertFiles("/etc/tls/tls.crt", "/etc/tls/tls.key"),
)
if err != nil {
t.Fatal("NewServer() =", err)
}

// Verify TLSConfig is set with the specified settings
if s.http.TLSConfig == nil {
t.Fatal("expected TLSConfig to be set")
}

if s.http.TLSConfig.MinVersion != tls.VersionTLS13 {
t.Errorf("expected MinVersion TLS 1.3, got %v", s.http.TLSConfig.MinVersion)
}

if s.http.TLSConfig.MaxVersion != tls.VersionTLS13 {
t.Errorf("expected MaxVersion TLS 1.3, got %v", s.http.TLSConfig.MaxVersion)
}

if len(s.http.TLSConfig.CipherSuites) != 2 {
t.Errorf("expected 2 cipher suites, got %d", len(s.http.TLSConfig.CipherSuites))
}

// Verify cert files are also set
if s.certFile != "/etc/tls/tls.crt" {
t.Errorf("expected certFile=/etc/tls/tls.crt, got %s", s.certFile)
}

if s.keyFile != "/etc/tls/tls.key" {
t.Errorf("expected keyFile=/etc/tls/tls.key, got %s", s.keyFile)
}

Check failure on line 215 in observability/metrics/prometheus/server_test.go

View workflow job for this annotation

GitHub Actions / style / Golang / Lint

File is not properly formatted (gofumpt)
}

Check failure on line 216 in observability/metrics/prometheus/server_test.go

View workflow job for this annotation

GitHub Actions / style / Golang / Lint

unnecessary trailing newline (whitespace)
11 changes: 9 additions & 2 deletions observability/metrics/prometheus_enabled.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import (
"knative.dev/pkg/observability/metrics/prometheus"
)

func buildPrometheus(_ context.Context, cfg Config) (sdkmetric.Reader, shutdownFunc, error) {
func buildPrometheus(ctx context.Context, cfg Config) (sdkmetric.Reader, shutdownFunc, error) {
r, err := otelprom.New(
otelprom.WithTranslationStrategy(otlptranslator.UnderscoreEscapingWithSuffixes),
)
Expand All @@ -54,10 +54,17 @@ func buildPrometheus(_ context.Context, cfg Config) (sdkmetric.Reader, shutdownF
)
}

// Check for TLSConfig in context (highest priority)
if tlsConfig := prometheus.TLSConfigFromContext(ctx); tlsConfig != nil {
opts = append(opts, prometheus.WithTLSConfig(tlsConfig))
}

server, err := prometheus.NewServer(opts...)

go func() {
server.ListenAndServe()
if err := server.ListenAndServe(); err != nil {
fmt.Printf("metrics server error: %v\n", err)
}
}()

return r, server.Shutdown, err
Expand Down
Loading