Skip to content

Commit

Permalink
fix(healtcheck): respect force-tls-disabled and force-non-ssl-session…
Browse files Browse the repository at this point in the history
…-port cluster properties
  • Loading branch information
karol-kokoszka committed Feb 14, 2024
1 parent d86640b commit 6a11ea1
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 20 deletions.
1 change: 1 addition & 0 deletions pkg/cmd/scylla-manager/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ func (s *server) makeServices() error {
s.config.Healthcheck,
s.clusterSvc.Client,
secretsStore,
s.clusterSvc.GetClusterByID,
s.logger.Named("healthcheck"),
)
if err != nil {
Expand Down
3 changes: 3 additions & 0 deletions pkg/service/cluster/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ import (
"go.uber.org/multierr"
)

// ProviderFunc defines the function that will be used by other services to get current cluster data.
type ProviderFunc func(ctx context.Context, id uuid.UUID) (*Cluster, error)

// ChangeType specifies type on Change.
type ChangeType int8

Expand Down
58 changes: 42 additions & 16 deletions pkg/service/healthcheck/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (

"github.com/pkg/errors"
"github.com/scylladb/go-log"
"github.com/scylladb/scylla-manager/v3/pkg/service/cluster"
"golang.org/x/sync/errgroup"

"github.com/scylladb/scylla-manager/v3/pkg/ping"
Expand Down Expand Up @@ -51,17 +52,23 @@ func (pt pingType) String() string {
return "unknown"
}

type tlsConfigWithAddress struct {
*tls.Config
Address string
}

type nodeInfo struct {
*scyllaclient.NodeInfo
TLSConfig map[pingType]*tls.Config
TLSConfig map[pingType]*tlsConfigWithAddress
Expires time.Time
}

// Service manages health checks.
type Service struct {
config Config
scyllaClient scyllaclient.ProviderFunc
secretsStore store.Store
config Config
scyllaClient scyllaclient.ProviderFunc
secretsStore store.Store
clusterProvider cluster.ProviderFunc

cacheMu sync.Mutex
// fields below are protected by cacheMu
Expand All @@ -70,17 +77,20 @@ type Service struct {
logger log.Logger
}

func NewService(config Config, scyllaClient scyllaclient.ProviderFunc, secretsStore store.Store, logger log.Logger) (*Service, error) {
func NewService(config Config, scyllaClient scyllaclient.ProviderFunc, secretsStore store.Store,
clusterProvider cluster.ProviderFunc, logger log.Logger,
) (*Service, error) {
if scyllaClient == nil {
return nil, errors.New("invalid scylla provider")
}

return &Service{
config: config,
scyllaClient: scyllaClient,
secretsStore: secretsStore,
nodeInfoCache: make(map[clusterIDHost]nodeInfo),
logger: logger,
config: config,
scyllaClient: scyllaClient,
secretsStore: secretsStore,
clusterProvider: clusterProvider,
nodeInfoCache: make(map[clusterIDHost]nodeInfo),
logger: logger,
}, nil
}

Expand Down Expand Up @@ -354,7 +364,7 @@ func (s *Service) pingCQL(ctx context.Context, clusterID uuid.UUID, host string,

tlsConfig := ni.tlsConfig(cqlPing)
if tlsConfig != nil {
config.Addr = ni.CQLSSLAddr(host)
config.Addr = tlsConfig.Address
config.TLSConfig = tlsConfig.Clone()
}

Expand Down Expand Up @@ -399,6 +409,12 @@ func (s *Service) nodeInfo(ctx context.Context, clusterID uuid.UUID, host string
if ni, ok := s.nodeInfoCache[key]; ok && now.Before(ni.Expires) {
return ni, nil
}

c, err := s.clusterProvider(ctx, clusterID)
if err != nil {
return nodeInfo{}, nil
}

client, err := s.scyllaClient(ctx, clusterID)
if err != nil {
return nodeInfo{}, errors.Wrap(err, "create scylla client")
Expand All @@ -409,13 +425,20 @@ func (s *Service) nodeInfo(ctx context.Context, clusterID uuid.UUID, host string
return nodeInfo{}, errors.Wrap(err, "fetch node info")
}

ni.TLSConfig = make(map[pingType]*tls.Config, 2)
ni.TLSConfig = make(map[pingType]*tlsConfigWithAddress, 2)
for _, p := range []pingType{alternatorPing, cqlPing} {
var tlsEnabled, clientCertAuth bool
var address string
if p == cqlPing {
address = ni.CQLAddr(host)
tlsEnabled, clientCertAuth = ni.CQLTLSEnabled()
tlsEnabled = tlsEnabled && !c.ForceTLSDisabled
if tlsEnabled && !c.ForceNonSSLSessionPort {
address = ni.CQLSSLAddr(host)
}
} else if p == alternatorPing {
tlsEnabled, clientCertAuth = ni.AlternatorTLSEnabled()
address = ni.AlternatorAddr(host)
}
if tlsEnabled {
tlsConfig, err := s.tlsConfig(clusterID, clientCertAuth)
Expand All @@ -427,7 +450,10 @@ func (s *Service) nodeInfo(ctx context.Context, clusterID uuid.UUID, host string
"cluster_id", clusterID,
)
} else {
ni.TLSConfig[p] = tlsConfig
ni.TLSConfig[p] = &tlsConfigWithAddress{
Config: tlsConfig,
Address: address,
}
}
}
}
Expand All @@ -439,7 +465,7 @@ func (s *Service) nodeInfo(ctx context.Context, clusterID uuid.UUID, host string
}

func (s *Service) tlsConfig(clusterID uuid.UUID, clientCertAuth bool) (*tls.Config, error) {
cfg := &tls.Config{
cfg := tls.Config{
InsecureSkipVerify: true,
}

Expand All @@ -457,7 +483,7 @@ func (s *Service) tlsConfig(clusterID uuid.UUID, clientCertAuth bool) (*tls.Conf
cfg.Certificates = []tls.Certificate{keyPair}
}

return cfg, nil
return &cfg, nil
}

func (s *Service) cqlCreds(ctx context.Context, clusterID uuid.UUID) *secrets.CQLCreds {
Expand Down Expand Up @@ -486,7 +512,7 @@ func (s *Service) InvalidateCache(clusterID uuid.UUID) {
s.cacheMu.Unlock()
}

func (ni nodeInfo) tlsConfig(pt pingType) *tls.Config {
func (ni nodeInfo) tlsConfig(pt pingType) *tlsConfigWithAddress {
return ni.TLSConfig[pt]
}

Expand Down
18 changes: 14 additions & 4 deletions pkg/service/healthcheck/service_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/scylladb/go-log"
"github.com/scylladb/scylla-manager/v3/pkg/metrics"
"github.com/scylladb/scylla-manager/v3/pkg/service/cluster"
"go.uber.org/zap/zapcore"

"github.com/scylladb/scylla-manager/v3/pkg/schema/table"
Expand All @@ -41,9 +43,12 @@ func TestStatusIntegration(t *testing.T) {
defer session.Close()

clusterID := uuid.MustRandom()

s := store.NewTableStore(session, table.Secrets)
testStatusIntegration(t, clusterID, s)
clusterSvc, err := cluster.NewService(session, metrics.NewClusterMetrics(), s, scyllaclient.DefaultTimeoutConfig(), log.NewDevelopment())
if err != nil {
t.Fatal(err)
}
testStatusIntegration(t, clusterID, clusterSvc.GetClusterByID, s)
}

func TestStatusWithCQLCredentialsIntegration(t *testing.T) {
Expand All @@ -65,11 +70,15 @@ func TestStatusWithCQLCredentialsIntegration(t *testing.T) {
}); err != nil {
t.Fatal(err)
}
clusterSvc, err := cluster.NewService(session, metrics.NewClusterMetrics(), s, scyllaclient.DefaultTimeoutConfig(), log.NewDevelopment())
if err != nil {
t.Fatal(err)
}

testStatusIntegration(t, clusterID, s)
testStatusIntegration(t, clusterID, clusterSvc.GetClusterByID, s)
}

func testStatusIntegration(t *testing.T, clusterID uuid.UUID, secretsStore store.Store) {
func testStatusIntegration(t *testing.T, clusterID uuid.UUID, clusterProvider cluster.ProviderFunc, secretsStore store.Store) {
logger := log.NewDevelopmentWithLevel(zapcore.InfoLevel).Named("healthcheck")

// Tests here do not test the dynamic t/o functionality
Expand Down Expand Up @@ -97,6 +106,7 @@ func testStatusIntegration(t *testing.T, clusterID uuid.UUID, secretsStore store
return scyllaclient.NewClient(sc, logger.Named("scylla"))
},
secretsStore,
clusterProvider,
logger,
)
if err != nil {
Expand Down

0 comments on commit 6a11ea1

Please sign in to comment.