diff --git a/observability/metrics/prometheus/server.go b/observability/metrics/prometheus/server.go index 9aa1db0c30..e467dc9ffa 100644 --- a/observability/metrics/prometheus/server.go +++ b/observability/metrics/prometheus/server.go @@ -18,6 +18,7 @@ package prometheus import ( "context" + "crypto/tls" "fmt" "net" "net/http" @@ -36,12 +37,16 @@ const ( defaultPrometheusHost = "" // IPv4 and IPv6 prometheusPortEnvName = "METRICS_PROMETHEUS_PORT" prometheusHostEnvName = "METRICS_PROMETHEUS_HOST" + prometheusTLSCertEnvName = "METRICS_TLS_CERT" + prometheusTLSKeyEnvName = "METRICS_TLS_KEY" ) type ServerOption func(*options) type Server struct { - http *http.Server + http *http.Server + certFile string + keyFile string } func NewServer(opts ...ServerOption) (*Server, error) { @@ -56,6 +61,8 @@ func NewServer(opts ...ServerOption) (*Server, error) { envOverride(&o.host, prometheusHostEnvName) envOverride(&o.port, prometheusPortEnvName) + envOverride(&o.certFile, prometheusTLSCertEnvName) + envOverride(&o.keyFile, prometheusTLSKeyEnvName) if err := validate(&o); err != nil { return nil, err @@ -68,16 +75,21 @@ func NewServer(opts ...ServerOption) (*Server, error) { return &Server{ http: &http.Server{ - Addr: addr, - Handler: mux, - // https://medium.com/a-journey-with-go/go-understand-and-mitigate-slowloris-attack-711c1b1403f6 + Addr: addr, + Handler: mux, + TLSConfig: o.tlsConfig, ReadHeaderTimeout: 5 * time.Second, }, + certFile: o.certFile, + keyFile: o.keyFile, }, nil } -func (s *Server) ListenAndServe() { - s.http.ListenAndServe() +func (s *Server) ListenAndServe() error { + if s.http.TLSConfig != nil || (s.certFile != "" && s.keyFile != "") { + return s.http.ListenAndServeTLS(s.certFile, s.keyFile) + } + return s.http.ListenAndServe() } func (s *Server) Shutdown(ctx context.Context) error { @@ -85,8 +97,11 @@ func (s *Server) Shutdown(ctx context.Context) error { } type options struct { - host string - port string + host string + port string + tlsConfig *tls.Config + certFile string + keyFile string } func WithHost(host string) ServerOption { @@ -101,6 +116,22 @@ func WithPort(port string) ServerOption { } } +// WithTLSConfig configures the server to use the provided TLS configuration. +// This allows programmatic control over TLS settings like MinVersion, CipherSuites, etc. +func WithTLSConfig(cfg *tls.Config) ServerOption { + return func(o *options) { + o.tlsConfig = cfg + } +} + +// WithTLSCertFiles configures the server to use TLS with the provided certificate and key files. +func WithTLSCertFiles(certFile, keyFile string) ServerOption { + return func(o *options) { + o.certFile = certFile + o.keyFile = keyFile + } +} + func validate(o *options) error { port, err := strconv.ParseUint(o.port, 10, 16) if err != nil { @@ -122,3 +153,21 @@ func envOverride(target *string, envName string) { *target = val } } + +type tlsConfigKey struct{} + +// ContextWithTLSConfig adds a TLS configuration to the context. +// This allows programmatic configuration of TLS settings like MinVersion, CipherSuites, ClientAuth, etc. +// when using frameworks like sharedmain where direct NewServer() options aren't accessible. +func ContextWithTLSConfig(ctx context.Context, cfg *tls.Config) context.Context { + return context.WithValue(ctx, tlsConfigKey{}, cfg) +} + +// TLSConfigFromContext retrieves the TLS configuration from the context. +// Returns nil if no TLS configuration was set. +func TLSConfigFromContext(ctx context.Context) *tls.Config { + if cfg, ok := ctx.Value(tlsConfigKey{}).(*tls.Config); ok { + return cfg + } + return nil +} diff --git a/observability/metrics/prometheus/server_test.go b/observability/metrics/prometheus/server_test.go index 75da9e112c..a11e923d6b 100644 --- a/observability/metrics/prometheus/server_test.go +++ b/observability/metrics/prometheus/server_test.go @@ -17,6 +17,8 @@ limitations under the License. package prometheus import ( + "context" + "crypto/tls" "testing" "github.com/google/go-cmp/cmp" @@ -72,3 +74,143 @@ func TestNewServerFailure(t *testing.T) { t.Error("expected above port range to fail") } } + +func TestNewServerWithTLSConfig(t *testing.T) { + tlsConfig := &tls.Config{ + MinVersion: tls.VersionTLS13, + MaxVersion: tls.VersionTLS13, + } + + s, err := NewServer(WithTLSConfig(tlsConfig)) + if err != nil { + t.Fatal("NewServer() =", err) + } + + if s.http.TLSConfig == nil { + t.Error("expected TLSConfig to be set on http.Server") + } + + if s.http.TLSConfig.MinVersion != tls.VersionTLS13 { + t.Errorf("expected MinVersion to be TLS 1.3, got %v", s.http.TLSConfig.MinVersion) + } +} + +func TestNewServerWithTLSCertFiles(t *testing.T) { + s, err := NewServer(WithTLSCertFiles("/path/to/cert.pem", "/path/to/key.pem")) + if err != nil { + t.Fatal("NewServer() =", err) + } + + if s.certFile != "/path/to/cert.pem" { + t.Errorf("expected certFile to be /path/to/cert.pem, got %s", s.certFile) + } + + if s.keyFile != "/path/to/key.pem" { + t.Errorf("expected keyFile to be /path/to/key.pem, got %s", s.keyFile) + } +} + +func TestNewServerWithTLSEnvVars(t *testing.T) { + t.Setenv(prometheusTLSCertEnvName, "/etc/tls/tls.crt") + t.Setenv(prometheusTLSKeyEnvName, "/etc/tls/tls.key") + + s, err := NewServer() + if err != nil { + t.Fatal("NewServer() =", err) + } + + if s.certFile != "/etc/tls/tls.crt" { + t.Errorf("expected certFile to be /etc/tls/tls.crt, got %s", s.certFile) + } + + if s.keyFile != "/etc/tls/tls.key" { + t.Errorf("expected keyFile to be /etc/tls/tls.key, got %s", s.keyFile) + } +} + +func TestNewServerTLSEnvOverridesOption(t *testing.T) { + t.Setenv(prometheusTLSCertEnvName, "/env/cert.pem") + t.Setenv(prometheusTLSKeyEnvName, "/env/key.pem") + + s, err := NewServer(WithTLSCertFiles("/opt/cert.pem", "/opt/key.pem")) + if err != nil { + t.Fatal("NewServer() =", err) + } + + if s.certFile != "/env/cert.pem" { + t.Errorf("expected env var to override option, got certFile=%s", s.certFile) + } + + if s.keyFile != "/env/key.pem" { + t.Errorf("expected env var to override option, got keyFile=%s", s.keyFile) + } +} + +func TestContextWithTLSConfig(t *testing.T) { + ctx := context.Background() + + if cfg := TLSConfigFromContext(ctx); cfg != nil { + t.Error("expected TLSConfigFromContext to return nil for empty context") + } + + tlsConfig := &tls.Config{ + MinVersion: tls.VersionTLS12, + } + + ctx = ContextWithTLSConfig(ctx, tlsConfig) + + retrieved := TLSConfigFromContext(ctx) + if retrieved == nil { + t.Fatal("expected TLSConfigFromContext to return the config") + } + + if retrieved.MinVersion != tls.VersionTLS12 { + t.Errorf("expected MinVersion to be TLS 1.2, got %v", retrieved.MinVersion) + } +} + +func TestTLSConfigWithCertFiles(t *testing.T) { + tlsConfig := &tls.Config{ + MinVersion: tls.VersionTLS13, + MaxVersion: tls.VersionTLS13, + CipherSuites: []uint16{ + tls.TLS_AES_256_GCM_SHA384, + tls.TLS_CHACHA20_POLY1305_SHA256, + }, + } + + s, err := NewServer( + WithTLSConfig(tlsConfig), + WithTLSCertFiles("/etc/tls/tls.crt", "/etc/tls/tls.key"), + ) + if err != nil { + t.Fatal("NewServer() =", err) + } + + // Verify TLSConfig is set with the specified settings + if s.http.TLSConfig == nil { + t.Fatal("expected TLSConfig to be set") + } + + if s.http.TLSConfig.MinVersion != tls.VersionTLS13 { + t.Errorf("expected MinVersion TLS 1.3, got %v", s.http.TLSConfig.MinVersion) + } + + if s.http.TLSConfig.MaxVersion != tls.VersionTLS13 { + t.Errorf("expected MaxVersion TLS 1.3, got %v", s.http.TLSConfig.MaxVersion) + } + + if len(s.http.TLSConfig.CipherSuites) != 2 { + t.Errorf("expected 2 cipher suites, got %d", len(s.http.TLSConfig.CipherSuites)) + } + + // Verify cert files are also set + if s.certFile != "/etc/tls/tls.crt" { + t.Errorf("expected certFile=/etc/tls/tls.crt, got %s", s.certFile) + } + + if s.keyFile != "/etc/tls/tls.key" { + t.Errorf("expected keyFile=/etc/tls/tls.key, got %s", s.keyFile) + } + +} diff --git a/observability/metrics/prometheus_enabled.go b/observability/metrics/prometheus_enabled.go index ed4082574f..d09f56a703 100644 --- a/observability/metrics/prometheus_enabled.go +++ b/observability/metrics/prometheus_enabled.go @@ -29,7 +29,7 @@ import ( "knative.dev/pkg/observability/metrics/prometheus" ) -func buildPrometheus(_ context.Context, cfg Config) (sdkmetric.Reader, shutdownFunc, error) { +func buildPrometheus(ctx context.Context, cfg Config) (sdkmetric.Reader, shutdownFunc, error) { r, err := otelprom.New( otelprom.WithTranslationStrategy(otlptranslator.UnderscoreEscapingWithSuffixes), ) @@ -54,10 +54,17 @@ func buildPrometheus(_ context.Context, cfg Config) (sdkmetric.Reader, shutdownF ) } + // Check for TLSConfig in context (highest priority) + if tlsConfig := prometheus.TLSConfigFromContext(ctx); tlsConfig != nil { + opts = append(opts, prometheus.WithTLSConfig(tlsConfig)) + } + server, err := prometheus.NewServer(opts...) go func() { - server.ListenAndServe() + if err := server.ListenAndServe(); err != nil { + fmt.Printf("metrics server error: %v\n", err) + } }() return r, server.Shutdown, err