Skip to content

Commit

Permalink
added ability to run multiple instances
Browse files Browse the repository at this point in the history
  • Loading branch information
Andrew Chubatiuk committed May 4, 2023
1 parent 5e0c792 commit 368e04b
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 20 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ require (
github.com/hashicorp/yamux v0.1.1 // indirect
github.com/jinzhu/copier v0.3.5
github.com/mattn/go-isatty v0.0.18 // indirect
github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/oklog/run v1.1.0 // indirect
github.com/vmihailenco/msgpack v4.0.4+incompatible // indirect
golang.org/x/crypto v0.8.0
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -854,6 +854,8 @@ github.com/minio/c2goasm v0.0.0-20190812172519-36a3d3bbc4f3/go.mod h1:RagcQ7I8Ie
github.com/mitchellh/go-testing-interface v0.0.0-20171004221916-a61a99592b77/go.mod h1:kRemZodwjscx+RGhAo8eIhFbs2+BFgRtFPeD/KE+zxI=
github.com/mitchellh/go-testing-interface v1.14.1 h1:jrgshOhYAUVNMAJiKbEu7EqAwgJJ2JqpQmpLJOu07cU=
github.com/mitchellh/go-testing-interface v1.14.1/go.mod h1:gfgS7OtZj6MA4U1UrDRp04twqAjfvlZyCfX3sDjEym8=
github.com/mitchellh/hashstructure/v2 v2.0.2 h1:vGKWl0YJqUNxE8d+h8f6NJLcCJrgbhC4NcD46KavDd4=
github.com/mitchellh/hashstructure/v2 v2.0.2/go.mod h1:MG3aRVU/N29oo/V/IhBX8GR/zz4kQkprJgF2EVszyDE=
github.com/oklog/run v1.0.0/go.mod h1:dlhp/R75TPv97u0XWUtDeV/lRKWPKSdTuV0TZvrmrQA=
github.com/oklog/run v1.1.0 h1:GEenZ1cK0+q0+wsJew9qUg/DyD8k3JzYsZAi5gYi2mA=
github.com/oklog/run v1.1.0/go.mod h1:sVPdnTZT1zYwAJeCMu2Th4T21pA3FPOQRfWjQlk7DVU=
Expand Down
4 changes: 3 additions & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,14 @@ func main() {
var addr string
var ppid int
var proto string
var name string
var err error

log.SetFlags(0)

flag.IntVar(&ppid, "ppid", 0, "parent process pid")
flag.StringVar(&addr, "addr", os.Getenv("TF_SSH_PROVIDER_TUNNEL_ADDR"), "set rpc server address")
flag.StringVar(&name, "name", os.Getenv("TF_SSH_PROVIDER_TUNNEL_NAME"), "set rpc server name")
flag.StringVar(&proto, "proto", os.Getenv("TF_SSH_PROVIDER_TUNNEL_PROTO"), "set rpc server protocol")
flag.Parse()
if ppid == 0 {
Expand All @@ -65,7 +67,7 @@ func main() {
log.Fatalf("[ERROR] RPC server address wasn't set")
}
var sshTunnel ssh.SSHTunnel
if err := sshTunnel.Run(proto, addr, ppid); err != nil {
if err := sshTunnel.Run(proto, name, addr, ppid); err != nil {
log.Fatalf("[ERROR] failed to start SSH Tunnel:\n%s", err)
}
}
Expand Down
36 changes: 25 additions & 11 deletions provider/data_source_ssh_tunnel.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"github.com/hashicorp/terraform-plugin-framework/path"
"github.com/hashicorp/terraform-plugin-framework/schema/validator"
"github.com/hashicorp/terraform-plugin-framework/types"
"github.com/mitchellh/hashstructure/v2"
"github.com/stefansundin/terraform-provider-ssh/ssh"
)

Expand Down Expand Up @@ -193,22 +194,34 @@ func (d *SSHTunnelDataSource) Read(ctx context.Context, req datasource.ReadReque
data.Remote.Host = types.StringValue("localhost")
}

d.tunnel.Local = data.Local.ToEndpoint()
d.tunnel.Remote = data.Remote.ToEndpoint()

proto := "tcp"
if d.tunnel.Local.Socket != "" {
if data.Local.Socket.ValueString() != "" {
proto = "unix"
}

tunnelServer := ssh.NewSSHTunnelServer(d.tunnel)
tunnelServerInbound, err := net.Listen(proto, d.tunnel.Local.RandomPortString())
tunnel := &ssh.SSHTunnel{
User: d.tunnel.User,
Auth: d.tunnel.Auth,
Server: d.tunnel.Server,
Local: data.Local.ToEndpoint(),
Remote: data.Remote.ToEndpoint(),
}

tunnelServer := ssh.NewSSHTunnelServer(tunnel)
tunnelServerInbound, err := net.Listen(proto, tunnel.Local.RandomPortString())
if err != nil {
resp.Diagnostics.AddError("proxy process error", err.Error())
return
}

if err = rpc.Register(tunnelServer); err != nil {
hash, err := hashstructure.Hash(tunnel.Remote, hashstructure.FormatV2, nil)
if err != nil {
resp.Diagnostics.AddError("rpc service name error", err.Error())
}

serviceName := fmt.Sprintf("SSHTunnelServer.%d", hash)

if err = rpc.RegisterName(serviceName, tunnelServer); err != nil {
resp.Diagnostics.AddError("rpc registration error", err.Error())
return
}
Expand All @@ -229,6 +242,7 @@ func (d *SSHTunnelDataSource) Read(ctx context.Context, req datasource.ReadReque
env := []string{
fmt.Sprintf("TF_SSH_PROVIDER_TUNNEL_PROTO=%s", proto),
fmt.Sprintf("TF_SSH_PROVIDER_TUNNEL_ADDR=%s", tunnelServerInbound.Addr().String()),
fmt.Sprintf("TF_SSH_PROVIDER_TUNNEL_NAME=%s", serviceName),
fmt.Sprintf("TF_SSH_PROVIDER_TUNNEL_PPID=%d", os.Getppid()),
}
cmd.Env = append(cmd.Env, env...)
Expand Down Expand Up @@ -264,10 +278,10 @@ func (d *SSHTunnelDataSource) Read(ctx context.Context, req datasource.ReadReque

tunnelServerInbound.Close()

log.Printf("[DEBUG] local port: %v", d.tunnel.Local.Port)
data.Local.Port = types.Int64Value(int64(d.tunnel.Local.Port))
data.Local.Address = types.StringValue(d.tunnel.Local.Address())
data.Id = types.StringValue(d.tunnel.Local.Address())
log.Printf("[DEBUG] local port: %v", tunnel.Local.Port)
data.Local.Port = types.Int64Value(int64(tunnel.Local.Port))
data.Local.Address = types.StringValue(tunnel.Local.Address())
data.Id = types.StringValue(tunnel.Local.Address())
resp.Diagnostics.Append(resp.State.Set(ctx, &data)...)

return
Expand Down
10 changes: 5 additions & 5 deletions provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,11 +175,6 @@ func (p *SSHProvider) Configure(ctx context.Context, req provider.ConfigureReque
}

sshTunnel.Auth = []ssh.SSHAuth{}
if authSock != "" {
sshTunnel.Auth = append(sshTunnel.Auth, ssh.SSHAuthSock{
Path: authSock,
})
}
privateKey := ssh.SSHPrivateKey{}
if data.Auth.PrivateKey.Content.ValueString() != "" {
privateKey.PrivateKey = data.Auth.PrivateKey.Content.ValueString()
Expand All @@ -191,6 +186,11 @@ func (p *SSHProvider) Configure(ctx context.Context, req provider.ConfigureReque
privateKey.Certificate = data.Auth.PrivateKey.Certificate.ValueString()
}
sshTunnel.Auth = append(sshTunnel.Auth, privateKey)
if authSock != "" {
sshTunnel.Auth = append(sshTunnel.Auth, ssh.SSHAuthSock{
Path: authSock,
})
}
if data.Auth.Password.ValueString() != "" {
sshTunnel.Auth = append(sshTunnel.Auth, ssh.SSHPassword{
Password: data.Auth.Password.ValueString(),
Expand Down
6 changes: 3 additions & 3 deletions ssh/tunnel.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ type SSHTunnel struct {
Auth []SSHAuth
}

func (st *SSHTunnel) Run(proto, serverAddress string, ppid int) error {
func (st *SSHTunnel) Run(proto, serverName, serverAddress string, ppid int) error {
log.Println("[DEBUG] creating SSH Tunnel")
var ack bool
gob.Register(SSHPrivateKey{})
Expand All @@ -139,7 +139,7 @@ func (st *SSHTunnel) Run(proto, serverAddress string, ppid int) error {
}

defer client.Close()
err = client.Call("SSHTunnelServer.GetSSHTunnel", &ack, &st)
err = client.Call(fmt.Sprintf("%s.GetSSHTunnel", serverName), &ack, &st)
if err != nil {
log.Fatalf("[ERROR] failed to execute a RPC call: %v", err)
}
Expand Down Expand Up @@ -197,7 +197,7 @@ func (st *SSHTunnel) Run(proto, serverAddress string, ppid int) error {
}

log.Printf("[DEBUG] sending PutSSHReady RPC call with port %d", st.Local.Port)
err = client.Call("SSHTunnelServer.PutSSHReady", st.Local.Port, &ack)
err = client.Call(fmt.Sprintf("%s.PutSSHReady", serverName), st.Local.Port, &ack)
if err != nil {
log.Fatal("[ERROR] failed to execute a RPC call:\n", err)
}
Expand Down

0 comments on commit 368e04b

Please sign in to comment.