Channel Member now wrapping User with metadata, new Auth struct.
authorAndrey Petrov <andrey.petrov@shazow.net>
Fri, 2 Jan 2015 02:40:10 +0000 (18:40 -0800)
committerAndrey Petrov <andrey.petrov@shazow.net>
Fri, 2 Jan 2015 02:40:10 +0000 (18:40 -0800)
chat/channel.go
chat/command.go
cmd.go
host.go
sshd/auth.go
sshd/terminal.go

index 3bdca403d77b7aae78b7bac79f53161d788685a3..b4de26e1543273466f9061c749db043eb0243b45 100644 (file)
@@ -13,12 +13,17 @@ const channelBuffer = 10
 // closed.
 var ErrChannelClosed = errors.New("channel closed")
 
+// Member is a User with per-Channel metadata attached to it.
+type Member struct {
+       *User
+       Op bool
+}
+
 // Channel definition, also a Set of User Items
 type Channel struct {
        topic     string
        history   *History
-       users     *Set
-       ops       *Set
+       members   *Set
        broadcast chan Message
        commands  Commands
        closed    bool
@@ -32,8 +37,7 @@ func NewChannel() *Channel {
        return &Channel{
                broadcast: broadcast,
                history:   NewHistory(historyLen),
-               users:     NewSet(),
-               ops:       NewSet(),
+               members:   NewSet(),
                commands:  *defaultCommands,
        }
 }
@@ -47,10 +51,10 @@ func (ch *Channel) SetCommands(commands Commands) {
 func (ch *Channel) Close() {
        ch.closeOnce.Do(func() {
                ch.closed = true
-               ch.users.Each(func(u Item) {
+               ch.members.Each(func(u Item) {
                        u.(*User).Close()
                })
-               ch.users.Clear()
+               ch.members.Clear()
                close(ch.broadcast)
        })
 }
@@ -75,8 +79,8 @@ func (ch *Channel) HandleMsg(m Message) {
                        skipUser = fromMsg.From()
                }
 
-               ch.users.Each(func(u Item) {
-                       user := u.(*User)
+               ch.members.Each(func(u Item) {
+                       user := u.(*Member).User
                        if skip && skipUser == user {
                                // Skip
                                return
@@ -108,18 +112,18 @@ func (ch *Channel) Join(u *User) error {
        if ch.closed {
                return ErrChannelClosed
        }
-       err := ch.users.Add(u)
+       err := ch.members.Add(&Member{u, false})
        if err != nil {
                return err
        }
-       s := fmt.Sprintf("%s joined. (Connected: %d)", u.Name(), ch.users.Len())
+       s := fmt.Sprintf("%s joined. (Connected: %d)", u.Name(), ch.members.Len())
        ch.Send(NewAnnounceMsg(s))
        return nil
 }
 
-// Leave the channel as a user, will announce.
+// Leave the channel as a user, will announce. Mostly used during setup.
 func (ch *Channel) Leave(u *User) error {
-       err := ch.users.Remove(u)
+       err := ch.members.Remove(u)
        if err != nil {
                return err
        }
@@ -128,6 +132,26 @@ func (ch *Channel) Leave(u *User) error {
        return nil
 }
 
+// Member returns a corresponding Member object to a User if the Member is
+// present in this channel.
+func (ch *Channel) Member(u *User) (*Member, bool) {
+       m, err := ch.members.Get(u.Id())
+       if err != nil {
+               return nil, false
+       }
+       // Check that it's the same user
+       if m.(*Member).User != u {
+               return nil, false
+       }
+       return m.(*Member), true
+}
+
+// IsOp returns whether a user is an operator in this channel.
+func (ch *Channel) IsOp(u *User) bool {
+       m, ok := ch.Member(u)
+       return ok && m.Op
+}
+
 // Topic of the channel.
 func (ch *Channel) Topic() string {
        return ch.topic
@@ -141,9 +165,9 @@ func (ch *Channel) SetTopic(s string) {
 // NamesPrefix lists all members' names with a given prefix, used to query
 // for autocompletion purposes.
 func (ch *Channel) NamesPrefix(prefix string) []string {
-       users := ch.users.ListPrefix(prefix)
-       names := make([]string, len(users))
-       for i, u := range users {
+       members := ch.members.ListPrefix(prefix)
+       names := make([]string, len(members))
+       for i, u := range members {
                names[i] = u.(*User).Name()
        }
        return names
index 50d50e5637bf366ce37c45919dc1405ed358d768..431ecb9be650039536efe45f9842866723b76348 100644 (file)
@@ -98,9 +98,8 @@ func init() {
        c.Add(Command{
                Prefix: "/help",
                Handler: func(channel *Channel, msg CommandMsg) error {
-                       user := msg.From()
-                       op := channel.ops.In(user)
-                       channel.Send(NewSystemMsg(channel.commands.Help(op), user))
+                       op := channel.IsOp(msg.From())
+                       channel.Send(NewSystemMsg(channel.commands.Help(op), msg.From()))
                        return nil
                },
        })
@@ -193,11 +192,12 @@ func init() {
        })
 
        c.Add(Command{
+               Op:         true,
                Prefix:     "/op",
                PrefixHelp: "USER",
                Help:       "Mark user as admin.",
                Handler: func(channel *Channel, msg CommandMsg) error {
-                       if !channel.ops.In(msg.From()) {
+                       if !channel.IsOp(msg.From()) {
                                return errors.New("must be op")
                        }
 
@@ -206,13 +206,14 @@ func init() {
                                return errors.New("must specify user")
                        }
 
-                       // TODO: Add support for fingerprint-based op'ing.
-                       user, err := channel.users.Get(Id(args[0]))
+                       // TODO: Add support for fingerprint-based op'ing. This will
+                       // probably need to live in host land.
+                       member, err := channel.members.Get(Id(args[0]))
                        if err != nil {
                                return errors.New("user not found")
                        }
 
-                       channel.ops.Add(user)
+                       member.(*Member).Op = true
                        return nil
                },
        })
diff --git a/cmd.go b/cmd.go
index 8c8a576910742873b99ef15d8b87b32eae07f6ee..e7bfc68a051a4e30231532366294b0b90a31a2b5 100644 (file)
--- a/cmd.go
+++ b/cmd.go
@@ -1,6 +1,7 @@
 package main
 
 import (
+       "bufio"
        "fmt"
        "io/ioutil"
        "net/http"
@@ -93,8 +94,8 @@ func main() {
                return
        }
 
-       // TODO: MakeAuth
-       config := sshd.MakeNoAuth()
+       auth := Auth{}
+       config := sshd.MakeAuth(auth)
        config.AddHostKey(signer)
 
        s, err := sshd.ListenSSH(options.Bind, config)
@@ -106,11 +107,10 @@ func main() {
        defer s.Close()
 
        host := NewHost(s)
-       go host.Serve()
+       host.auth = &auth
 
-       /* TODO:
        for _, fingerprint := range options.Admin {
-               server.Op(fingerprint)
+               auth.Op(fingerprint)
        }
 
        if options.Whitelist != "" {
@@ -123,7 +123,7 @@ func main() {
 
                scanner := bufio.NewScanner(file)
                for scanner.Scan() {
-                       server.Whitelist(scanner.Text())
+                       auth.Whitelist(scanner.Text())
                }
        }
 
@@ -137,9 +137,10 @@ func main() {
                // hack to normalize line endings into \r\n
                motdString = strings.Replace(motdString, "\r\n", "\n", -1)
                motdString = strings.Replace(motdString, "\n", "\r\n", -1)
-               server.SetMotd(motdString)
+               host.SetMotd(motdString)
        }
-       */
+
+       go host.Serve()
 
        // Construct interrupt handler
        sig := make(chan os.Signal, 1)
diff --git a/host.go b/host.go
index 680d553c3ed8d1d4c2e5d006708af9f445e9ac7f..c6c8de72544c6c30f99746f21e50f0a397ac7af3 100644 (file)
--- a/host.go
+++ b/host.go
@@ -14,9 +14,12 @@ import (
 type Host struct {
        listener *sshd.SSHListener
        channel  *chat.Channel
+
+       motd string
+       auth *Auth
 }
 
-// NewHost creates a Host on top of an existing listener
+// NewHost creates a Host on top of an existing listener.
 func NewHost(listener *sshd.SSHListener) *Host {
        ch := chat.NewChannel()
        h := Host{
@@ -27,7 +30,12 @@ func NewHost(listener *sshd.SSHListener) *Host {
        return &h
 }
 
-// Connect a specific Terminal to this host and its channel
+// SetMotd sets the host's message of the day.
+func (h *Host) SetMotd(motd string) {
+       h.motd = motd
+}
+
+// Connect a specific Terminal to this host and its channel.
 func (h *Host) Connect(term *sshd.Terminal) {
        name := term.Conn.User()
        term.AutoCompleteCallback = h.AutoCompleteFunction
index d271a85545ea513b7a268d05f629ec6a921818ab..90134e5c88a3108d48995703b1a77160279d2715 100644 (file)
@@ -9,13 +9,9 @@ import (
        "golang.org/x/crypto/ssh"
 )
 
-var errBanned = errors.New("banned")
-var errNotWhitelisted = errors.New("not whitelisted")
-var errNoInteractive = errors.New("public key authentication required")
-
 type Auth interface {
-       IsBanned(ssh.PublicKey) bool
-       IsWhitelisted(ssh.PublicKey) bool
+       AllowAnonymous() bool
+       Check(string) (bool, error)
 }
 
 func MakeAuth(auth Auth) *ssh.ServerConfig {
@@ -23,21 +19,17 @@ func MakeAuth(auth Auth) *ssh.ServerConfig {
                NoClientAuth: false,
                // Auth-related things should be constant-time to avoid timing attacks.
                PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
-                       if auth.IsBanned(key) {
-                               return nil, errBanned
-                       }
-                       if !auth.IsWhitelisted(key) {
-                               return nil, errNotWhitelisted
+                       fingerprint := Fingerprint(key)
+                       ok, err := auth.Check(fingerprint)
+                       if !ok {
+                               return nil, err
                        }
-                       perm := &ssh.Permissions{Extensions: map[string]string{"fingerprint": Fingerprint(key)}}
+                       perm := &ssh.Permissions{Extensions: map[string]string{"fingerprint": fingerprint}}
                        return perm, nil
                },
                KeyboardInteractiveCallback: func(conn ssh.ConnMetadata, challenge ssh.KeyboardInteractiveChallenge) (*ssh.Permissions, error) {
-                       if auth.IsBanned(nil) {
-                               return nil, errNoInteractive
-                       }
-                       if !auth.IsWhitelisted(nil) {
-                               return nil, errNotWhitelisted
+                       if !auth.AllowAnonymous() {
+                               return nil, errors.New("public key authentication required")
                        }
                        return nil, nil
                },
@@ -51,7 +43,8 @@ func MakeNoAuth() *ssh.ServerConfig {
                NoClientAuth: false,
                // Auth-related things should be constant-time to avoid timing attacks.
                PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
-                       return nil, nil
+                       perm := &ssh.Permissions{Extensions: map[string]string{"fingerprint": Fingerprint(key)}}
+                       return perm, nil
                },
                KeyboardInteractiveCallback: func(conn ssh.ConnMetadata, challenge ssh.KeyboardInteractiveChallenge) (*ssh.Permissions, error) {
                        return nil, nil
index 14318b3c97bd38d9f2abcfcc0d8de893b91cdb0e..196b9bea1737357d128965d5da0075d678f58979 100644 (file)
@@ -11,12 +11,12 @@ import (
 // Extending ssh/terminal to include a closer interface
 type Terminal struct {
        terminal.Terminal
-       Conn    ssh.Conn
+       Conn    *ssh.ServerConn
        Channel ssh.Channel
 }
 
 // Make new terminal from a session channel
-func NewTerminal(conn ssh.Conn, ch ssh.NewChannel) (*Terminal, error) {
+func NewTerminal(conn *ssh.ServerConn, ch ssh.NewChannel) (*Terminal, error) {
        if ch.ChannelType() != "session" {
                return nil, errors.New("terminal requires session channel")
        }
@@ -41,7 +41,7 @@ func NewTerminal(conn ssh.Conn, ch ssh.NewChannel) (*Terminal, error) {
 }
 
 // Find session channel and make a Terminal from it
-func NewSession(conn ssh.Conn, channels <-chan ssh.NewChannel) (term *Terminal, err error) {
+func NewSession(conn *ssh.ServerConn, channels <-chan ssh.NewChannel) (term *Terminal, err error) {
        for ch := range channels {
                if t := ch.ChannelType(); t != "session" {
                        ch.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %s", t))