Resolve name collision to GuestX, with test.
authorAndrey Petrov <andrey.petrov@shazow.net>
Wed, 7 Jan 2015 05:42:57 +0000 (21:42 -0800)
committerAndrey Petrov <andrey.petrov@shazow.net>
Wed, 7 Jan 2015 05:42:57 +0000 (21:42 -0800)
cmd.go
host.go
host_test.go
sshd/client.go [new file with mode: 0644]
sshd/net_test.go

diff --git a/cmd.go b/cmd.go
index e7bfc68a051a4e30231532366294b0b90a31a2b5..b97ec183a1025bf68a80d56de6f61dbb9996d33c 100644 (file)
--- a/cmd.go
+++ b/cmd.go
@@ -108,6 +108,7 @@ func main() {
 
        host := NewHost(s)
        host.auth = &auth
+       host.theme = &chat.Themes[0]
 
        for _, fingerprint := range options.Admin {
                auth.Op(fingerprint)
diff --git a/host.go b/host.go
index 175ba118a19db1f90e26755aa12d4e759480e014..376c9660781c6b4fe7209499d6b10b9c17c47689 100644 (file)
--- a/host.go
+++ b/host.go
@@ -16,8 +16,12 @@ type Host struct {
        channel  *chat.Channel
        commands *chat.Commands
 
-       motd string
-       auth *Auth
+       motd  string
+       auth  *Auth
+       count int
+
+       // Default theme
+       theme *chat.Theme
 }
 
 // NewHost creates a Host on top of an existing listener.
@@ -48,7 +52,7 @@ func (h *Host) Connect(term *sshd.Terminal) {
        term.AutoCompleteCallback = h.AutoCompleteFunction
 
        user := chat.NewUserScreen(name, term)
-       user.Config.Theme = &chat.Themes[0]
+       user.Config.Theme = h.theme
        go func() {
                // Close term once user is closed.
                user.Wait()
@@ -56,14 +60,21 @@ func (h *Host) Connect(term *sshd.Terminal) {
        }()
        defer user.Close()
 
-       term.SetPrompt(GetPrompt(user))
-
        err := h.channel.Join(user)
+       if err == chat.ErrIdTaken {
+               // Try again...
+               user.SetName(fmt.Sprintf("Guest%d", h.count))
+               err = h.channel.Join(user)
+       }
        if err != nil {
                logger.Errorf("Failed to join: %s", err)
                return
        }
 
+       // Successfully joined.
+       term.SetPrompt(GetPrompt(user))
+       h.count++
+
        for {
                line, err := term.ReadLine()
                if err == io.EOF {
index 882c5f99b3eca5899a94ccc824a093fb3702f748..d86c3531523f9ff394fe2c67bf8bcf6c2fa83ce5 100644 (file)
@@ -1,11 +1,23 @@
 package main
 
 import (
+       "bufio"
+       "io"
+       "strings"
        "testing"
 
        "github.com/shazow/ssh-chat/chat"
+       "github.com/shazow/ssh-chat/sshd"
 )
 
+func stripPrompt(s string) string {
+       pos := strings.LastIndex(s, "\033[K")
+       if pos < 0 {
+               return s
+       }
+       return s[pos+3:]
+}
+
 func TestHostGetPrompt(t *testing.T) {
        var expected, actual string
 
@@ -15,13 +27,88 @@ func TestHostGetPrompt(t *testing.T) {
        actual = GetPrompt(u)
        expected = "[foo] "
        if actual != expected {
-               t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
+               t.Errorf("Got: %q; Expected: %q", actual, expected)
        }
 
        u.Config.Theme = &chat.Themes[0]
        actual = GetPrompt(u)
        expected = "[\033[38;05;2mfoo\033[0m] "
        if actual != expected {
-               t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
+               t.Errorf("Got: %q; Expected: %q", actual, expected)
+       }
+}
+
+func TestHostNameCollision(t *testing.T) {
+       key, err := sshd.NewRandomKey(512)
+       if err != nil {
+               t.Fatal(err)
+       }
+       config := sshd.MakeNoAuth()
+       config.AddHostKey(key)
+
+       s, err := sshd.ListenSSH(":0", config)
+       if err != nil {
+               t.Fatal(err)
+       }
+       host := NewHost(s)
+       go host.Serve()
+
+       done := make(chan struct{}, 1)
+
+       // First client
+       go func() {
+               err = sshd.NewClientSession(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) {
+                       scanner := bufio.NewScanner(r)
+
+                       // Consume the initial buffer
+                       scanner.Scan()
+                       actual := scanner.Text()
+                       if !strings.HasPrefix(actual, "[foo] ") {
+                               t.Errorf("First client failed to get 'foo' name.")
+                       }
+
+                       actual = stripPrompt(actual)
+                       expected := " * foo joined. (Connected: 1)"
+                       if actual != expected {
+                               t.Errorf("Got %q; expected %q", actual, expected)
+                       }
+
+                       // Ready for second client
+                       done <- struct{}{}
+
+                       scanner.Scan()
+                       actual = stripPrompt(scanner.Text())
+                       expected = " * Guest1 joined. (Connected: 2)"
+                       if actual != expected {
+                               t.Errorf("Got %q; expected %q", actual, expected)
+                       }
+
+                       // Wrap it up.
+                       close(done)
+               })
+               if err != nil {
+                       t.Fatal(err)
+               }
+       }()
+
+       // Wait for first client
+       <-done
+
+       // Second client
+       err = sshd.NewClientSession(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) {
+               scanner := bufio.NewScanner(r)
+
+               // Consume the initial buffer
+               scanner.Scan()
+               actual := scanner.Text()
+               if !strings.HasPrefix(actual, "[Guest1] ") {
+                       t.Errorf("Second client did not get Guest1 name.")
+               }
+       })
+       if err != nil {
+               t.Fatal(err)
        }
+
+       <-done
+       s.Close()
 }
diff --git a/sshd/client.go b/sshd/client.go
new file mode 100644 (file)
index 0000000..60dab6e
--- /dev/null
@@ -0,0 +1,65 @@
+package sshd
+
+import (
+       "crypto/rand"
+       "crypto/rsa"
+       "io"
+
+       "golang.org/x/crypto/ssh"
+)
+
+// NewRandomKey generates a random key of a desired bit length.
+func NewRandomKey(bits int) (ssh.Signer, error) {
+       key, err := rsa.GenerateKey(rand.Reader, bits)
+       if err != nil {
+               return nil, err
+       }
+       return ssh.NewSignerFromKey(key)
+}
+
+// NewClientConfig creates a barebones ssh.ClientConfig to be used with ssh.Dial.
+func NewClientConfig(name string) *ssh.ClientConfig {
+       return &ssh.ClientConfig{
+               User: name,
+               Auth: []ssh.AuthMethod{
+                       ssh.KeyboardInteractive(func(user, instruction string, questions []string, echos []bool) (answers []string, err error) {
+                               return
+                       }),
+               },
+       }
+}
+
+// NewClientSession makes a barebones SSH client session, used for testing.
+func NewClientSession(host string, name string, handler func(r io.Reader, w io.WriteCloser)) error {
+       config := NewClientConfig(name)
+       conn, err := ssh.Dial("tcp", host, config)
+       if err != nil {
+               return err
+       }
+       defer conn.Close()
+
+       session, err := conn.NewSession()
+       if err != nil {
+               return err
+       }
+       defer session.Close()
+
+       in, err := session.StdinPipe()
+       if err != nil {
+               return err
+       }
+
+       out, err := session.StdoutPipe()
+       if err != nil {
+               return err
+       }
+
+       err = session.Shell()
+       if err != nil {
+               return err
+       }
+
+       handler(out, in)
+
+       return nil
+}
index 6ec4311cedb746e371cc84d7a77e938cebc0ac36..8321b301049b60714435a88ee63d9e78d98d7c86 100644 (file)
@@ -2,66 +2,12 @@ package sshd
 
 import (
        "bytes"
-       "crypto/rand"
-       "crypto/rsa"
        "io"
        "testing"
-
-       "golang.org/x/crypto/ssh"
 )
 
 // TODO: Move some of these into their own package?
 
-func MakeKey(bits int) (ssh.Signer, error) {
-       key, err := rsa.GenerateKey(rand.Reader, bits)
-       if err != nil {
-               return nil, err
-       }
-       return ssh.NewSignerFromKey(key)
-}
-
-func NewClientSession(host string, name string, handler func(r io.Reader, w io.WriteCloser)) error {
-       config := &ssh.ClientConfig{
-               User: name,
-               Auth: []ssh.AuthMethod{
-                       ssh.KeyboardInteractive(func(user, instruction string, questions []string, echos []bool) (answers []string, err error) {
-                               return
-                       }),
-               },
-       }
-
-       conn, err := ssh.Dial("tcp", host, config)
-       if err != nil {
-               return err
-       }
-       defer conn.Close()
-
-       session, err := conn.NewSession()
-       if err != nil {
-               return err
-       }
-       defer session.Close()
-
-       in, err := session.StdinPipe()
-       if err != nil {
-               return err
-       }
-
-       out, err := session.StdoutPipe()
-       if err != nil {
-               return err
-       }
-
-       err = session.Shell()
-       if err != nil {
-               return err
-       }
-
-       handler(out, in)
-
-       return nil
-}
-
 func TestServerInit(t *testing.T) {
        config := MakeNoAuth()
        s, err := ListenSSH(":badport", config)
@@ -81,7 +27,7 @@ func TestServerInit(t *testing.T) {
 }
 
 func TestServeTerminals(t *testing.T) {
-       signer, err := MakeKey(512)
+       signer, err := NewRandomKey(512)
        config := MakeNoAuth()
        config.AddHostKey(signer)