terminal: Disconnect sooner and more reliably
authorAndrey Petrov <andrey.petrov@shazow.net>
Mon, 25 Jul 2016 02:56:38 +0000 (22:56 -0400)
committerAndrey Petrov <andrey.petrov@shazow.net>
Mon, 25 Jul 2016 02:56:38 +0000 (22:56 -0400)
sshd/terminal.go

index e71749becbbd09b68481e12de4df5f1fb5c52fc4..8d4b7257dc6d1a34561da7b5c55ea91130da0158 100644 (file)
@@ -4,6 +4,7 @@ import (
        "errors"
        "fmt"
        "net"
+       "sync"
        "time"
 
        "golang.org/x/crypto/ssh"
@@ -13,6 +14,9 @@ import (
 var keepaliveInterval = time.Second * 30
 var keepaliveRequest = "keepalive@ssh-chat"
 
+var ErrNoSessionChannel = errors.New("no session channel")
+var ErrNotSessionChannel = errors.New("terminal requires session channel")
+
 // Connection is an interface with fields necessary to operate an sshd host.
 type Connection interface {
        PublicKey() ssh.PublicKey
@@ -52,37 +56,43 @@ type Terminal struct {
        terminal.Terminal
        Conn    Connection
        Channel ssh.Channel
+
+       done      chan struct{}
+       closeOnce sync.Once
 }
 
 // Make new terminal from a session channel
 func NewTerminal(conn *ssh.ServerConn, ch ssh.NewChannel) (*Terminal, error) {
        if ch.ChannelType() != "session" {
-               return nil, errors.New("terminal requires session channel")
+               return nil, ErrNotSessionChannel
        }
        channel, requests, err := ch.Accept()
        if err != nil {
                return nil, err
        }
        term := Terminal{
-               *terminal.NewTerminal(channel, "Connecting..."),
-               sshConn{conn},
-               channel,
+               Terminal: *terminal.NewTerminal(channel, "Connecting..."),
+               Conn:     sshConn{conn},
+               Channel:  channel,
+
+               done: make(chan struct{}),
        }
 
        go term.listen(requests)
-       go func() {
-               // FIXME: Is this necessary?
-               conn.Wait()
-               channel.Close()
-       }()
 
        go func() {
-               for range time.Tick(keepaliveInterval) {
-                       // TODO: Could break out earlier with a select if we want, rather than waiting for an error.
-                       _, err := channel.SendRequest(keepaliveRequest, true, nil)
-                       if err != nil {
-                               // Connection is gone
-                               conn.Close()
+               // Keep-Alive Ticker
+               ticker := time.Tick(keepaliveInterval)
+               for {
+                       select {
+                       case <-ticker:
+                               _, err := channel.SendRequest(keepaliveRequest, true, nil)
+                               if err != nil {
+                                       // Connection is gone
+                                       term.Close()
+                                       return
+                               }
+                       case <-term.done:
                                return
                        }
                }
@@ -92,35 +102,29 @@ func NewTerminal(conn *ssh.ServerConn, ch ssh.NewChannel) (*Terminal, error) {
 }
 
 // Find session channel and make a Terminal from it
-func NewSession(conn *ssh.ServerConn, channels <-chan ssh.NewChannel) (term *Terminal, err error) {
+func NewSession(conn *ssh.ServerConn, channels <-chan ssh.NewChannel) (*Terminal, error) {
+       // Make a terminal from the first session found
        for ch := range channels {
                if t := ch.ChannelType(); t != "session" {
                        ch.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %s", t))
                        continue
                }
 
-               term, err = NewTerminal(conn, ch)
-               if err == nil {
-                       break
-               }
-       }
-
-       if term != nil {
-               // Reject the rest.
-               // FIXME: Do we need this?
-               go func() {
-                       for ch := range channels {
-                               ch.Reject(ssh.Prohibited, "only one session allowed")
-                       }
-               }()
+               return NewTerminal(conn, ch)
        }
 
-       return term, err
+       return nil, ErrNoSessionChannel
 }
 
 // Close terminal and ssh connection
 func (t *Terminal) Close() error {
-       return t.Conn.Close()
+       var err error
+       t.closeOnce.Do(func() {
+               close(t.done)
+               t.Channel.Close()
+               err = t.Conn.Close()
+       })
+       return err
 }
 
 // Negotiate terminal type and settings