diff --git a/cmd/cloudflared/access/carrier.go b/cmd/cloudflared/access/carrier.go index 52aacc5655f..9fb89d79f1b 100644 --- a/cmd/cloudflared/access/carrier.go +++ b/cmd/cloudflared/access/carrier.go @@ -5,8 +5,10 @@ import ( "fmt" "io" "net/http" + "os" "strings" + "github.com/mitchellh/go-homedir" "github.com/pkg/errors" "github.com/rs/zerolog" "github.com/urfave/cli/v2" @@ -69,6 +71,21 @@ func ssh(c *cli.Context) error { } log := logger.CreateSSHLoggerFromContext(c, outputTerminal) + if pidFile := c.String(sshPidFileFlag); pidFile != "" { + expandedPidFile, err := homedir.Expand(pidFile) + if err != nil { + log.Err(err).Msg("unable to expand pidfile path") + } else if err := writePidFile(expandedPidFile, log); err != nil { + log.Err(err).Msg("failed to write pidfile") + } else { + defer func() { + if err := os.Remove(expandedPidFile); err != nil { + log.Err(err).Msg("failed to remove pidfile") + } + }() + } + } + // get the hostname from the cmdline and error out if its not provided rawHostName := c.String(sshHostnameFlag) url, err := parseURL(rawHostName) @@ -145,3 +162,17 @@ func ssh(c *cli.Context) error { } return carrier.StartClient(wsConn, s, options) } + +// writePidFile writes the current process ID to the given path. +func writePidFile(path string, log *zerolog.Logger) error { + file, err := os.Create(path) + if err != nil { + return fmt.Errorf("unable to create pidfile %q: %w", path, err) + } + defer file.Close() + if _, err := fmt.Fprintf(file, "%d", os.Getpid()); err != nil { + return fmt.Errorf("unable to write pid to %q: %w", path, err) + } + log.Info().Str("pidfile", path).Msg("wrote pidfile") + return nil +} diff --git a/cmd/cloudflared/access/cmd.go b/cmd/cloudflared/access/cmd.go index 636b9288e27..29c265bc96f 100644 --- a/cmd/cloudflared/access/cmd.go +++ b/cmd/cloudflared/access/cmd.go @@ -38,6 +38,7 @@ const ( sshGenCertFlag = "short-lived-cert" sshConnectTo = "connect-to" sshDebugStream = "debug-stream" + sshPidFileFlag = "pidfile" sshConfigTemplate = ` Add to your {{.Home}}/.ssh/config: @@ -204,6 +205,11 @@ func Commands() []*cli.Command { Hidden: true, Usage: "Writes up-to the max provided stream payloads to the logger as debug statements.", }, + &cli.StringFlag{ + Name: sshPidFileFlag, + Usage: "Write the application's PID to this file after startup.", + EnvVars: []string{"TUNNEL_PIDFILE"}, + }, }, }, {