diff --git a/cmd/wsh/cmd/wshcmd-ssh.go b/cmd/wsh/cmd/wshcmd-ssh.go index 25dad3c098..4eb1d42a4e 100644 --- a/cmd/wsh/cmd/wshcmd-ssh.go +++ b/cmd/wsh/cmd/wshcmd-ssh.go @@ -1,4 +1,4 @@ -// Copyright 2025, Command Line Inc. +// Copyright 2026, Command Line Inc. // SPDX-License-Identifier: Apache-2.0 package cmd @@ -7,6 +7,7 @@ import ( "fmt" "github.com/spf13/cobra" + "github.com/wavetermdev/waveterm/pkg/remote" "github.com/wavetermdev/waveterm/pkg/waveobj" "github.com/wavetermdev/waveterm/pkg/wconfig" "github.com/wavetermdev/waveterm/pkg/wshrpc" @@ -15,6 +16,8 @@ import ( var ( identityFiles []string + sshLogin string + sshPort string newBlock bool ) @@ -28,6 +31,8 @@ var sshCmd = &cobra.Command{ func init() { sshCmd.Flags().StringArrayVarP(&identityFiles, "identityfile", "i", []string{}, "add an identity file for publickey authentication") + sshCmd.Flags().StringVarP(&sshLogin, "login", "l", "", "set the remote login name") + sshCmd.Flags().StringVarP(&sshPort, "port", "p", "", "set the remote port") sshCmd.Flags().BoolVarP(&newBlock, "new", "n", false, "create a new terminal block with this connection") rootCmd.AddCommand(sshCmd) } @@ -38,6 +43,11 @@ func sshRun(cmd *cobra.Command, args []string) (rtnErr error) { }() sshArg := args[0] + var err error + sshArg, err = applySSHOverrides(sshArg, sshLogin, sshPort) + if err != nil { + return err + } blockId := RpcContext.BlockId if blockId == "" && !newBlock { return fmt.Errorf("cannot determine blockid (not in JWT)") @@ -91,10 +101,27 @@ func sshRun(cmd *cobra.Command, args []string) (rtnErr error) { waveobj.MetaKey_CmdCwd: nil, }, } - err := wshclient.SetMetaCommand(RpcClient, data, nil) + err = wshclient.SetMetaCommand(RpcClient, data, nil) if err != nil { return fmt.Errorf("setting connection in block: %w", err) } WriteStderr("switched connection to %q\n", sshArg) return nil } + +func applySSHOverrides(sshArg string, login string, port string) (string, error) { + if login == "" && port == "" { + return sshArg, nil + } + opts, err := remote.ParseOpts(sshArg) + if err != nil { + return "", err + } + if login != "" { + opts.SSHUser = login + } + if port != "" { + opts.SSHPort = port + } + return opts.String(), nil +} diff --git a/cmd/wsh/cmd/wshcmd-ssh_test.go b/cmd/wsh/cmd/wshcmd-ssh_test.go new file mode 100644 index 0000000000..36da037464 --- /dev/null +++ b/cmd/wsh/cmd/wshcmd-ssh_test.go @@ -0,0 +1,75 @@ +// Copyright 2026, Command Line Inc. +// SPDX-License-Identifier: Apache-2.0 + +package cmd + +import "testing" + +func TestApplySSHOverrides(t *testing.T) { + tests := []struct { + name string + sshArg string + login string + port string + want string + wantErr bool + }{ + { + name: "no overrides preserves target", + sshArg: "root@bar.com:2022", + want: "root@bar.com:2022", + }, + { + name: "login override replaces parsed user", + sshArg: "root@bar.com", + login: "foo", + want: "foo@bar.com", + }, + { + name: "port override replaces parsed port", + sshArg: "root@bar.com:2022", + port: "2222", + want: "root@bar.com:2222", + }, + { + name: "both overrides replace parsed user and port", + sshArg: "root@bar.com:2022", + login: "foo", + port: "2200", + want: "foo@bar.com:2200", + }, + { + name: "login override adds user to bare host", + sshArg: "bar.com", + login: "foo", + want: "foo@bar.com", + }, + { + name: "port override adds port to bare host", + sshArg: "bar.com", + port: "2200", + want: "bar.com:2200", + }, + { + name: "invalid target returns parse error when override requested", + sshArg: "bad host", + login: "foo", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := applySSHOverrides(tt.sshArg, tt.login, tt.port) + if (err != nil) != tt.wantErr { + t.Fatalf("applySSHOverrides() error = %v, wantErr %v", err, tt.wantErr) + } + if tt.wantErr { + return + } + if got != tt.want { + t.Fatalf("applySSHOverrides() = %q, want %q", got, tt.want) + } + }) + } +}