Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Require certificate registered within cluster before choosing CQL SSL #3699

Merged
merged 1 commit into from
Jan 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 17 additions & 24 deletions pkg/scyllaclient/client_agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"context"
"net"
"net/url"
"time"

"github.com/pkg/errors"
scyllaversion "github.com/scylladb/scylla-manager/v3/pkg/util/version"
Expand Down Expand Up @@ -47,7 +46,17 @@ func (c *Client) AnyNodeInfo(ctx context.Context) (*NodeInfo, error) {
// is added.
// `fallback` argument is used in case any of above addresses is zero address.
func (ni *NodeInfo) CQLAddr(fallback string) string {
addr, port := ni.cqlAddr(fallback), ni.CQLPort(fallback)
addr, port := ni.cqlAddr(fallback), ni.CQLPort()
return net.JoinHostPort(addr, port)
}

// CQLSSLAddr returns CQL SSL address from NodeInfo.
// Scylla can have separate rpc_address (CQL), listen_address and respectfully
// broadcast_rpc_address and broadcast_address if some 3rd party routing
// is added.
// `fallback` argument is used in case any of above addresses is zero address.
func (ni *NodeInfo) CQLSSLAddr(fallback string) string {
addr, port := ni.cqlAddr(fallback), ni.CQLSSLPort()
return net.JoinHostPort(addr, port)
}

Expand All @@ -71,31 +80,15 @@ func (ni *NodeInfo) cqlAddr(fallback string) string {
}

// CQLPort returns CQL port from NodeInfo.
// `fallbackAddress` argument is needed for Scylla bug workaround, see CQLAddr for description.
func (ni *NodeInfo) CQLPort(fallbackAddress string) string {
if ni.ClientEncryptionEnabled {
// Scylla API always returns non-empty NativeTransportPortSSL even when
// value is explicitly disabled in configuration file.
// This makes impossible to determine which port is being used for CQL
// frontend. To workaround it, we try to dial SSL port when
// client encryption is enabled. If any error happens, assume this port
// is not used.
// Ref: https://github.com/scylladb/scylla/issues/7206

d := &net.Dialer{
Timeout: time.Second,
}
addr := net.JoinHostPort(ni.cqlAddr(fallbackAddress), ni.NativeTransportPortSsl)
c, err := d.Dial("tcp", addr)
if err != nil {
return ni.NativeTransportPort
}
defer c.Close()
return ni.NativeTransportPortSsl
}
func (ni *NodeInfo) CQLPort() string {
return ni.NativeTransportPort
}

// CQLSSLPort returns CQL SSL port from NodeInfo.
func (ni *NodeInfo) CQLSSLPort() string {
return ni.NativeTransportPortSsl
}

// AlternatorEnabled returns if Alternator is enabled on host.
func (ni *NodeInfo) AlternatorEnabled() bool {
return (ni.AlternatorHTTPSPort != "0" && ni.AlternatorHTTPSPort != "") ||
Expand Down
137 changes: 75 additions & 62 deletions pkg/scyllaclient/client_agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (

"github.com/google/go-cmp/cmp"
"github.com/scylladb/scylla-manager/v3/pkg/scyllaclient"
"go.uber.org/atomic"
)

const fallback = "4.3.2.1"
Expand Down Expand Up @@ -80,26 +79,6 @@ func TestNodeInfoCQLAddr(t *testing.T) {
},
GoldenAddress: net.JoinHostPort(fallback, "1234"),
},
{
Name: "Native Transport Port SSL with client encryption enabled without server listening on",
NodeInfo: &scyllaclient.NodeInfo{
NativeTransportPort: "4321",
NativeTransportPortSsl: "1234",
ListenAddress: "1.2.3.4",
ClientEncryptionEnabled: true,
},
GoldenAddress: net.JoinHostPort("1.2.3.4", "4321"),
},
{
Name: "Native Transport Port SSL with client encryption disabled",
NodeInfo: &scyllaclient.NodeInfo{
NativeTransportPort: "4321",
NativeTransportPortSsl: "1234",
ListenAddress: "1.2.3.4",
ClientEncryptionEnabled: false,
},
GoldenAddress: net.JoinHostPort("1.2.3.4", "4321"),
},
}

for i := range table {
Expand All @@ -115,51 +94,85 @@ func TestNodeInfoCQLAddr(t *testing.T) {
}
}

// Test workaround used in NodeInfo.CQLPort().
func TestNodeInfoCQLAddrNativeTransportPortSSL(t *testing.T) {
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
defer l.Close()
func TestNodeInfoCQLSSLAddr(t *testing.T) {
t.Parallel()

address, port, err := net.SplitHostPort(l.Addr().String())
if err != nil {
t.Fatal(err)
table := []struct {
Name string
NodeInfo *scyllaclient.NodeInfo
GoldenAddress string
}{
{
Name: "Broadcast RPC address is set",
NodeInfo: &scyllaclient.NodeInfo{
BroadcastRPCAddress: "1.2.3.4",
RPCAddress: "1.2.3.5",
ListenAddress: "1.2.3.6",
NativeTransportPortSsl: "1234",
},
GoldenAddress: "1.2.3.4:1234",
},
{
Name: "RPC address is set",
NodeInfo: &scyllaclient.NodeInfo{
NativeTransportPortSsl: "1234",
RPCAddress: "1.2.3.5",
ListenAddress: "1.2.3.6",
},
GoldenAddress: "1.2.3.5:1234",
},
{
Name: "Listen Address is set",
NodeInfo: &scyllaclient.NodeInfo{
NativeTransportPortSsl: "1234",
ListenAddress: "1.2.3.6",
},
GoldenAddress: "1.2.3.6:1234",
},
{
Name: "Fallback is returned when RPC Address is IPv4 zero",
NodeInfo: &scyllaclient.NodeInfo{
NativeTransportPortSsl: "1234",
RPCAddress: "0.0.0.0",
},
GoldenAddress: net.JoinHostPort(fallback, "1234"),
},
{
Name: "Fallback is returned when RPC Address is IPv6 zero",
NodeInfo: &scyllaclient.NodeInfo{
NativeTransportPortSsl: "1234",
RPCAddress: "::0",
},
GoldenAddress: net.JoinHostPort(fallback, "1234"),
},
{
Name: "Fallback is returned when Listen Address is IPv4 zero",
NodeInfo: &scyllaclient.NodeInfo{
NativeTransportPortSsl: "1234",
ListenAddress: "0.0.0.0",
},
GoldenAddress: net.JoinHostPort(fallback, "1234"),
},
{
Name: "Fallback is returned when Listen Address is IPv6 zero",
NodeInfo: &scyllaclient.NodeInfo{
NativeTransportPortSsl: "1234",
ListenAddress: "::0",
},
GoldenAddress: net.JoinHostPort(fallback, "1234"),
},
}

var (
connections atomic.Int64
ready = make(chan struct{})
)
go func() {
for {
c, err := l.Accept()
if err != nil {
return
}
connections.Inc()
_ = c.Close()
close(ready)
break
}
}()

ni := &scyllaclient.NodeInfo{
NativeTransportPort: "4321",
NativeTransportPortSsl: port,
ListenAddress: address,
ClientEncryptionEnabled: true,
}
addr := ni.CQLAddr(fallback)
golden := net.JoinHostPort(ni.ListenAddress, ni.NativeTransportPortSsl)
if addr != golden {
t.Errorf("expected %s address, got %s", golden, addr)
}
for i := range table {
test := table[i]
t.Run(test.Name, func(t *testing.T) {
t.Parallel()

<-ready
if c := connections.Load(); c == 0 {
t.Errorf("expected connection during figuring out CQL port got %d", c)
addr := test.NodeInfo.CQLSSLAddr(fallback)
if addr != test.GoldenAddress {
t.Errorf("expected %s address, got %s", test.GoldenAddress, addr)
}
})
}
}

Expand Down
37 changes: 30 additions & 7 deletions pkg/service/cluster/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"crypto/tls"
"fmt"
"sort"
"strconv"

"github.com/gocql/gocql"
"github.com/pkg/errors"
Expand Down Expand Up @@ -549,6 +550,7 @@ func (s *Service) GetSession(ctx context.Context, clusterID uuid.UUID) (session
}

scyllaCluster := gocql.NewCluster(sessionHosts...)
cqlPort := ni.CQLPort()

if ni.CqlPasswordProtected {
credentials := secrets.CQLCreds{
Expand All @@ -568,20 +570,41 @@ func (s *Service) GetSession(ctx context.Context, clusterID uuid.UUID) (session
}
}

if ni.ClientEncryptionEnabled {
keyPair, err := s.loadTLSIdentity(clusterID)
if err != nil && !errors.Is(err, service.ErrNotFound) {
return session, err
}

if ni.ClientEncryptionEnabled && !ni.ClientEncryptionRequireAuth {
cqlPort = ni.CQLSSLPort()
scyllaCluster.SslOpts = &gocql.SslOptions{
Config: &tls.Config{
InsecureSkipVerify: true,
},
}
if ni.ClientEncryptionRequireAuth {
keyPair, err := s.loadTLSIdentity(clusterID)
if err != nil {
return session, err
}
scyllaCluster.SslOpts.Config.Certificates = []tls.Certificate{keyPair}
}

if ni.ClientEncryptionEnabled && ni.ClientEncryptionRequireAuth && !errors.Is(err, service.ErrNotFound) {
cqlPort = ni.CQLSSLPort()
scyllaCluster.SslOpts = &gocql.SslOptions{
Config: &tls.Config{
InsecureSkipVerify: true,
},
}
scyllaCluster.SslOpts.Config.Certificates = []tls.Certificate{keyPair}
}

if ni.ClientEncryptionEnabled && ni.ClientEncryptionRequireAuth && errors.Is(err, service.ErrNotFound) {
s.logger.Info(ctx, "Client encryption is enabled, but Cluster wasn't registered with certificate in Scylla Manager, falling back to nonSSL port.",
"cluster_id", clusterID,
)
}

p, err := strconv.Atoi(cqlPort)
if err != nil {
return session, errors.Wrap(err, "parse cql port")
}
scyllaCluster.Port = p

return gocqlx.WrapSession(scyllaCluster.CreateSession())
}
Expand Down
11 changes: 9 additions & 2 deletions pkg/service/healthcheck/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,7 @@ func (s *Service) pingCQL(ctx context.Context, clusterID uuid.UUID, host string,

tlsConfig := ni.tlsConfig(cqlPing)
if tlsConfig != nil {
config.Addr = ni.CQLSSLAddr(host)
config.TLSConfig = tlsConfig.Clone()
}

Expand Down Expand Up @@ -418,10 +419,16 @@ func (s *Service) nodeInfo(ctx context.Context, clusterID uuid.UUID, host string
}
if tlsEnabled {
tlsConfig, err := s.tlsConfig(clusterID, clientCertAuth)
if err != nil {
if err != nil && !errors.Is(err, service.ErrNotFound) {
return ni, errors.Wrap(err, "fetch TLS config")
}
ni.TLSConfig[p] = tlsConfig
if clientCertAuth && errors.Is(err, service.ErrNotFound) {
s.logger.Info(ctx, "Client encryption is enabled, but Cluster wasn't registered with certificate in Scylla Manager, falling back to nonSSL port.",
"cluster_id", clusterID,
)
} else {
ni.TLSConfig[p] = tlsConfig
}
}
}

Expand Down
Loading