Skip to content

Commit

Permalink
fix(healtcheck): respect force-tls-disabled and force-non-ssl-session…
Browse files Browse the repository at this point in the history
…-port cluster properties
  • Loading branch information
karol-kokoszka committed Feb 21, 2024
1 parent ced1728 commit 7eb7458
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 31 deletions.
1 change: 1 addition & 0 deletions pkg/cmd/scylla-manager/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ func (s *server) makeServices() error {
s.config.Healthcheck,
s.clusterSvc.Client,
secretsStore,
s.clusterSvc.GetClusterByID,
s.logger.Named("healthcheck"),
)
if err != nil {
Expand Down
3 changes: 3 additions & 0 deletions pkg/service/cluster/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ import (
"go.uber.org/multierr"
)

// ProviderFunc defines the function that will be used by other services to get current cluster data.
type ProviderFunc func(ctx context.Context, id uuid.UUID) (*Cluster, error)

// ChangeType specifies type on Change.
type ChangeType int8

Expand Down
63 changes: 43 additions & 20 deletions pkg/service/healthcheck/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (

"github.com/pkg/errors"
"github.com/scylladb/go-log"
"github.com/scylladb/scylla-manager/v3/pkg/service/cluster"
"golang.org/x/sync/errgroup"

"github.com/scylladb/scylla-manager/v3/pkg/ping"
Expand Down Expand Up @@ -51,17 +52,23 @@ func (pt pingType) String() string {
return "unknown"
}

type tlsConfigWithAddress struct {
*tls.Config
Address string
}

type nodeInfo struct {
*scyllaclient.NodeInfo
TLSConfig map[pingType]*tls.Config
TLSConfig map[pingType]*tlsConfigWithAddress
Expires time.Time
}

// Service manages health checks.
type Service struct {
config Config
scyllaClient scyllaclient.ProviderFunc
secretsStore store.Store
config Config
scyllaClient scyllaclient.ProviderFunc
secretsStore store.Store
clusterProvider cluster.ProviderFunc

cacheMu sync.Mutex
// fields below are protected by cacheMu
Expand All @@ -70,17 +77,20 @@ type Service struct {
logger log.Logger
}

func NewService(config Config, scyllaClient scyllaclient.ProviderFunc, secretsStore store.Store, logger log.Logger) (*Service, error) {
func NewService(config Config, scyllaClient scyllaclient.ProviderFunc, secretsStore store.Store,
clusterProvider cluster.ProviderFunc, logger log.Logger,
) (*Service, error) {
if scyllaClient == nil {
return nil, errors.New("invalid scylla provider")
}

return &Service{
config: config,
scyllaClient: scyllaClient,
secretsStore: secretsStore,
nodeInfoCache: make(map[clusterIDHost]nodeInfo),
logger: logger,
config: config,
scyllaClient: scyllaClient,
secretsStore: secretsStore,
clusterProvider: clusterProvider,
nodeInfoCache: make(map[clusterIDHost]nodeInfo),
logger: logger,
}, nil
}

Expand Down Expand Up @@ -354,7 +364,7 @@ func (s *Service) pingCQL(ctx context.Context, clusterID uuid.UUID, host string,

tlsConfig := ni.tlsConfig(cqlPing)
if tlsConfig != nil {
config.Addr = ni.CQLSSLAddr(host)
config.Addr = tlsConfig.Address
config.TLSConfig = tlsConfig.Clone()
}

Expand Down Expand Up @@ -399,6 +409,12 @@ func (s *Service) nodeInfo(ctx context.Context, clusterID uuid.UUID, host string
if ni, ok := s.nodeInfoCache[key]; ok && now.Before(ni.Expires) {
return ni, nil
}

c, err := s.clusterProvider(ctx, clusterID)
if err != nil {
return nodeInfo{}, err
}

client, err := s.scyllaClient(ctx, clusterID)
if err != nil {
return nodeInfo{}, errors.Wrap(err, "create scylla client")
Expand All @@ -409,25 +425,32 @@ func (s *Service) nodeInfo(ctx context.Context, clusterID uuid.UUID, host string
return nodeInfo{}, errors.Wrap(err, "fetch node info")
}

ni.TLSConfig = make(map[pingType]*tls.Config, 2)
ni.TLSConfig = make(map[pingType]*tlsConfigWithAddress, 2)
for _, p := range []pingType{alternatorPing, cqlPing} {
var tlsEnabled, clientCertAuth bool
var address string
if p == cqlPing {
address = ni.CQLAddr(host)
tlsEnabled, clientCertAuth = ni.CQLTLSEnabled()
tlsEnabled = tlsEnabled && !c.ForceTLSDisabled
if tlsEnabled && !c.ForceNonSSLSessionPort {
address = ni.CQLSSLAddr(host)
}
} else if p == alternatorPing {
tlsEnabled, clientCertAuth = ni.AlternatorTLSEnabled()
address = ni.AlternatorAddr(host)
}
if tlsEnabled {
tlsConfig, err := s.tlsConfig(clusterID, clientCertAuth)
if err != nil && !errors.Is(err, service.ErrNotFound) {
return ni, errors.Wrap(err, "fetch TLS config")
}
if clientCertAuth && errors.Is(err, service.ErrNotFound) {
s.logger.Info(ctx, "Client encryption is enabled, but Cluster wasn't registered with certificate in Scylla Manager, falling back to nonSSL port.",
"cluster_id", clusterID,
)
} else {
ni.TLSConfig[p] = tlsConfig
return nodeInfo{}, errors.Wrap(err, "client encryption is enabled, but certificate is missing")
}
ni.TLSConfig[p] = &tlsConfigWithAddress{
Config: tlsConfig,
Address: address,
}
}
}
Expand All @@ -439,7 +462,7 @@ func (s *Service) nodeInfo(ctx context.Context, clusterID uuid.UUID, host string
}

func (s *Service) tlsConfig(clusterID uuid.UUID, clientCertAuth bool) (*tls.Config, error) {
cfg := &tls.Config{
cfg := tls.Config{
InsecureSkipVerify: true,
}

Expand All @@ -457,7 +480,7 @@ func (s *Service) tlsConfig(clusterID uuid.UUID, clientCertAuth bool) (*tls.Conf
cfg.Certificates = []tls.Certificate{keyPair}
}

return cfg, nil
return &cfg, nil
}

func (s *Service) cqlCreds(ctx context.Context, clusterID uuid.UUID) *secrets.CQLCreds {
Expand Down Expand Up @@ -486,7 +509,7 @@ func (s *Service) InvalidateCache(clusterID uuid.UUID) {
s.cacheMu.Unlock()
}

func (ni nodeInfo) tlsConfig(pt pingType) *tls.Config {
func (ni nodeInfo) tlsConfig(pt pingType) *tlsConfigWithAddress {
return ni.TLSConfig[pt]
}

Expand Down
41 changes: 30 additions & 11 deletions pkg/service/healthcheck/service_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,12 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/scylladb/go-log"
"github.com/scylladb/scylla-manager/v3/pkg/metrics"
"github.com/scylladb/scylla-manager/v3/pkg/service/cluster"
"go.uber.org/zap/zapcore"

"github.com/scylladb/scylla-manager/v3/pkg/schema/table"
"github.com/scylladb/scylla-manager/v3/pkg/scyllaclient"
"github.com/scylladb/scylla-manager/v3/pkg/secrets"
"github.com/scylladb/scylla-manager/v3/pkg/store"
. "github.com/scylladb/scylla-manager/v3/pkg/testutils"
. "github.com/scylladb/scylla-manager/v3/pkg/testutils/db"
Expand All @@ -40,10 +41,22 @@ func TestStatusIntegration(t *testing.T) {
session := CreateScyllaManagerDBSession(t)
defer session.Close()

clusterID := uuid.MustRandom()

s := store.NewTableStore(session, table.Secrets)
testStatusIntegration(t, clusterID, s)
clusterSvc, err := cluster.NewService(session, metrics.NewClusterMetrics(), s, scyllaclient.DefaultTimeoutConfig(), log.NewDevelopment())
if err != nil {
t.Fatal(err)
}

c := &cluster.Cluster{
Host: "192.168.200.11",
AuthToken: "token",
}
err = clusterSvc.PutCluster(context.Background(), c)
if err != nil {
t.Fatal(err)
}

testStatusIntegration(t, c.ID, clusterSvc.GetClusterByID, s)
}

func TestStatusWithCQLCredentialsIntegration(t *testing.T) {
Expand All @@ -55,21 +68,26 @@ func TestStatusWithCQLCredentialsIntegration(t *testing.T) {
session := CreateScyllaManagerDBSession(t)
defer session.Close()

clusterID := uuid.MustRandom()

s := store.NewTableStore(session, table.Secrets)
if err := s.Put(&secrets.CQLCreds{
ClusterID: clusterID,
clusterSvc, err := cluster.NewService(session, metrics.NewClusterMetrics(), s, scyllaclient.DefaultTimeoutConfig(), log.NewDevelopment())
if err != nil {
t.Fatal(err)
}
c := &cluster.Cluster{
Host: "192.168.200.11",
AuthToken: "token",
Username: username,
Password: password,
}); err != nil {
}
err = clusterSvc.PutCluster(context.Background(), c)
if err != nil {
t.Fatal(err)
}

testStatusIntegration(t, clusterID, s)
testStatusIntegration(t, c.ID, clusterSvc.GetClusterByID, s)
}

func testStatusIntegration(t *testing.T, clusterID uuid.UUID, secretsStore store.Store) {
func testStatusIntegration(t *testing.T, clusterID uuid.UUID, clusterProvider cluster.ProviderFunc, secretsStore store.Store) {
logger := log.NewDevelopmentWithLevel(zapcore.InfoLevel).Named("healthcheck")

// Tests here do not test the dynamic t/o functionality
Expand Down Expand Up @@ -97,6 +115,7 @@ func testStatusIntegration(t *testing.T, clusterID uuid.UUID, secretsStore store
return scyllaclient.NewClient(sc, logger.Named("scylla"))
},
secretsStore,
clusterProvider,
logger,
)
if err != nil {
Expand Down

0 comments on commit 7eb7458

Please sign in to comment.