diff --git a/pkg/cmd/scylla-manager/server.go b/pkg/cmd/scylla-manager/server.go
index 447ce28714..ee2e2001de 100644
--- a/pkg/cmd/scylla-manager/server.go
+++ b/pkg/cmd/scylla-manager/server.go
@@ -81,6 +81,7 @@ func (s *server) makeServices() error {
 		s.config.Healthcheck,
 		s.clusterSvc.Client,
 		secretsStore,
+		s.clusterSvc.GetClusterByID,
 		s.logger.Named("healthcheck"),
 	)
 	if err != nil {
diff --git a/pkg/service/cluster/service.go b/pkg/service/cluster/service.go
index 3e5aed0b0c..8182ad501b 100644
--- a/pkg/service/cluster/service.go
+++ b/pkg/service/cluster/service.go
@@ -25,6 +25,9 @@ import (
 	"go.uber.org/multierr"
 )
 
+// ProviderFunc defines the function that will be used by other services to get current cluster data.
+type ProviderFunc func(ctx context.Context, id uuid.UUID) (*Cluster, error)
+
 // ChangeType specifies type on Change.
 type ChangeType int8
 
diff --git a/pkg/service/healthcheck/service.go b/pkg/service/healthcheck/service.go
index 89f0a5d943..dac78cc9fb 100644
--- a/pkg/service/healthcheck/service.go
+++ b/pkg/service/healthcheck/service.go
@@ -12,6 +12,7 @@ import (
 
 	"github.com/pkg/errors"
 	"github.com/scylladb/go-log"
+	"github.com/scylladb/scylla-manager/v3/pkg/service/cluster"
 	"golang.org/x/sync/errgroup"
 
 	"github.com/scylladb/scylla-manager/v3/pkg/ping"
@@ -51,17 +52,25 @@ func (pt pingType) String() string {
 	return "unknown"
 }
 
+// tlsConfigWithAddress pairs a TLS configuration with the address that is
+// dialed when that configuration is used for a ping.
+type tlsConfigWithAddress struct {
+	*tls.Config
+	Address string
+}
+
 type nodeInfo struct {
 	*scyllaclient.NodeInfo
-	TLSConfig map[pingType]*tls.Config
+	TLSConfig map[pingType]*tlsConfigWithAddress
 	Expires   time.Time
 }
 
 // Service manages health checks.
 type Service struct {
-	config       Config
-	scyllaClient scyllaclient.ProviderFunc
-	secretsStore store.Store
+	config          Config
+	scyllaClient    scyllaclient.ProviderFunc
+	secretsStore    store.Store
+	clusterProvider cluster.ProviderFunc
 
 	cacheMu sync.Mutex
 	// fields below are protected by cacheMu
@@ -70,17 +79,23 @@ type Service struct {
 	logger log.Logger
 }
 
-func NewService(config Config, scyllaClient scyllaclient.ProviderFunc, secretsStore store.Store, logger log.Logger) (*Service, error) {
+func NewService(config Config, scyllaClient scyllaclient.ProviderFunc, secretsStore store.Store,
+	clusterProvider cluster.ProviderFunc, logger log.Logger,
+) (*Service, error) {
 	if scyllaClient == nil {
 		return nil, errors.New("invalid scylla provider")
 	}
+	if clusterProvider == nil {
+		return nil, errors.New("invalid cluster provider")
+	}
 
 	return &Service{
-		config:        config,
-		scyllaClient:  scyllaClient,
-		secretsStore:  secretsStore,
-		nodeInfoCache: make(map[clusterIDHost]nodeInfo),
-		logger:        logger,
+		config:          config,
+		scyllaClient:    scyllaClient,
+		secretsStore:    secretsStore,
+		clusterProvider: clusterProvider,
+		nodeInfoCache:   make(map[clusterIDHost]nodeInfo),
+		logger:          logger,
 	}, nil
 }
 
@@ -354,7 +369,7 @@ func (s *Service) pingCQL(ctx context.Context, clusterID uuid.UUID, host string,
 
 	tlsConfig := ni.tlsConfig(cqlPing)
 	if tlsConfig != nil {
-		config.Addr = ni.CQLSSLAddr(host)
+		config.Addr = tlsConfig.Address
 		config.TLSConfig = tlsConfig.Clone()
 	}
 
@@ -399,6 +414,12 @@ func (s *Service) nodeInfo(ctx context.Context, clusterID uuid.UUID, host string
 	if ni, ok := s.nodeInfoCache[key]; ok && now.Before(ni.Expires) {
 		return ni, nil
 	}
+
+	c, err := s.clusterProvider(ctx, clusterID)
+	if err != nil {
+		return nodeInfo{}, errors.Wrap(err, "get cluster")
+	}
+
 	client, err := s.scyllaClient(ctx, clusterID)
 	if err != nil {
 		return nodeInfo{}, errors.Wrap(err, "create scylla client")
@@ -409,13 +430,22 @@ func (s *Service) nodeInfo(ctx context.Context, clusterID uuid.UUID, host string
 		return nodeInfo{}, errors.Wrap(err, "fetch node info")
 	}
 
-	ni.TLSConfig = make(map[pingType]*tls.Config, 2)
+	ni.TLSConfig = make(map[pingType]*tlsConfigWithAddress, 2)
 	for _, p := range []pingType{alternatorPing, cqlPing} {
 		var tlsEnabled, clientCertAuth bool
+		var address string
 		if p == cqlPing {
+			// Start from the plain CQL address; the SSL address is used only when
+			// TLS stays enabled and the cluster does not force the non-SSL port.
+			address = ni.CQLAddr(host)
 			tlsEnabled, clientCertAuth = ni.CQLTLSEnabled()
+			tlsEnabled = tlsEnabled && !c.ForceTLSDisabled
+			if tlsEnabled && !c.ForceNonSSLSessionPort {
+				address = ni.CQLSSLAddr(host)
+			}
 		} else if p == alternatorPing {
 			tlsEnabled, clientCertAuth = ni.AlternatorTLSEnabled()
+			address = ni.AlternatorAddr(host)
 		}
 		if tlsEnabled {
 			tlsConfig, err := s.tlsConfig(clusterID, clientCertAuth)
@@ -427,7 +457,10 @@ func (s *Service) nodeInfo(ctx context.Context, clusterID uuid.UUID, host string
 					"cluster_id", clusterID,
 				)
 			} else {
-				ni.TLSConfig[p] = tlsConfig
+				ni.TLSConfig[p] = &tlsConfigWithAddress{
+					Config:  tlsConfig,
+					Address: address,
+				}
 			}
 		}
 	}
@@ -439,7 +472,7 @@ func (s *Service) nodeInfo(ctx context.Context, clusterID uuid.UUID, host string
 }
 
 func (s *Service) tlsConfig(clusterID uuid.UUID, clientCertAuth bool) (*tls.Config, error) {
-	cfg := &tls.Config{
+	cfg := tls.Config{
 		InsecureSkipVerify: true,
 	}
 
@@ -457,7 +490,7 @@ func (s *Service) tlsConfig(clusterID uuid.UUID, clientCertAuth bool) (*tls.Conf
 		cfg.Certificates = []tls.Certificate{keyPair}
 	}
 
-	return cfg, nil
+	return &cfg, nil
 }
 
 func (s *Service) cqlCreds(ctx context.Context, clusterID uuid.UUID) *secrets.CQLCreds {
@@ -486,7 +519,7 @@ func (s *Service) InvalidateCache(clusterID uuid.UUID) {
 	s.cacheMu.Unlock()
 }
 
-func (ni nodeInfo) tlsConfig(pt pingType) *tls.Config {
+func (ni nodeInfo) tlsConfig(pt pingType) *tlsConfigWithAddress {
 	return ni.TLSConfig[pt]
 }
 
diff --git a/pkg/service/healthcheck/service_integration_test.go b/pkg/service/healthcheck/service_integration_test.go
index 3a8bc650ef..22374ba86f 100644
--- a/pkg/service/healthcheck/service_integration_test.go
+++ b/pkg/service/healthcheck/service_integration_test.go
@@ -18,6 +18,8 @@ import (
 	"github.com/google/go-cmp/cmp"
 	"github.com/google/go-cmp/cmp/cmpopts"
 	"github.com/scylladb/go-log"
+	"github.com/scylladb/scylla-manager/v3/pkg/metrics"
+	"github.com/scylladb/scylla-manager/v3/pkg/service/cluster"
 	"go.uber.org/zap/zapcore"
 
 	"github.com/scylladb/scylla-manager/v3/pkg/schema/table"
@@ -41,9 +43,12 @@ func TestStatusIntegration(t *testing.T) {
 	defer session.Close()
 
 	clusterID := uuid.MustRandom()
-
 	s := store.NewTableStore(session, table.Secrets)
-	testStatusIntegration(t, clusterID, s)
+	clusterSvc, err := cluster.NewService(session, metrics.NewClusterMetrics(), s, scyllaclient.DefaultTimeoutConfig(), log.NewDevelopment())
+	if err != nil {
+		t.Fatal(err)
+	}
+	testStatusIntegration(t, clusterID, clusterSvc.GetClusterByID, s)
 }
 
 func TestStatusWithCQLCredentialsIntegration(t *testing.T) {
@@ -65,11 +70,15 @@ func TestStatusWithCQLCredentialsIntegration(t *testing.T) {
 	}); err != nil {
 		t.Fatal(err)
 	}
+	clusterSvc, err := cluster.NewService(session, metrics.NewClusterMetrics(), s, scyllaclient.DefaultTimeoutConfig(), log.NewDevelopment())
+	if err != nil {
+		t.Fatal(err)
+	}
 
-	testStatusIntegration(t, clusterID, s)
+	testStatusIntegration(t, clusterID, clusterSvc.GetClusterByID, s)
 }
 
-func testStatusIntegration(t *testing.T, clusterID uuid.UUID, secretsStore store.Store) {
+func testStatusIntegration(t *testing.T, clusterID uuid.UUID, clusterProvider cluster.ProviderFunc, secretsStore store.Store) {
 	logger := log.NewDevelopmentWithLevel(zapcore.InfoLevel).Named("healthcheck")
 
 	// Tests here do not test the dynamic t/o functionality
@@ -97,6 +106,7 @@ func testStatusIntegration(t *testing.T, clusterID uuid.UUID, secretsStore store
 			return scyllaclient.NewClient(sc, logger.Named("scylla"))
 		},
 		secretsStore,
+		clusterProvider,
 		logger,
 	)
 	if err != nil {