|
@@ -30,6 +30,7 @@ import (
|
|
|
"bytes"
|
|
|
"encoding/base64"
|
|
|
"encoding/hex"
|
|
|
+ "errors"
|
|
|
"fmt"
|
|
|
"io"
|
|
|
"log"
|
|
@@ -38,12 +39,14 @@ import (
|
|
|
"strconv"
|
|
|
"strings"
|
|
|
|
|
|
+ "git.semlanik.org/semlanik/gostfix/auth"
|
|
|
"github.com/google/uuid"
|
|
|
)
|
|
|
|
|
|
type SaslServer struct {
|
|
|
- pid int
|
|
|
- cuid int
|
|
|
+ pid int
|
|
|
+ cuid int
|
|
|
+ authenticator *auth.Authenticator
|
|
|
}
|
|
|
|
|
|
const (
|
|
@@ -67,8 +70,9 @@ const (
|
|
|
|
|
|
func NewSaslServer() *SaslServer {
|
|
|
return &SaslServer{
|
|
|
- pid: os.Getpid(),
|
|
|
- cuid: 0,
|
|
|
+ pid: os.Getpid(),
|
|
|
+ cuid: 0,
|
|
|
+ authenticator: auth.NewAuthenticator(),
|
|
|
}
|
|
|
}
|
|
|
|
|
@@ -103,6 +107,8 @@ func (s *SaslServer) handleRequest(conn net.Conn) {
|
|
|
|
|
|
if err == io.EOF {
|
|
|
break
|
|
|
+ // time.Sleep(100 * time.Millisecond)
|
|
|
+ // continue
|
|
|
}
|
|
|
|
|
|
if err != nil {
|
|
@@ -110,14 +116,20 @@ func (s *SaslServer) handleRequest(conn net.Conn) {
|
|
|
}
|
|
|
|
|
|
currentMessage := fullbuf
|
|
|
- if strings.Index(currentMessage, Version) == 0 {
|
|
|
- versionIds := strings.Split(currentMessage, "\t")
|
|
|
+ fmt.Printf("SASL: %s\n", fullbuf)
|
|
|
|
|
|
- if len(versionIds) < 3 {
|
|
|
+ ids := strings.Split(currentMessage, "\t")
|
|
|
+ if len(ids) < 2 {
|
|
|
+ break
|
|
|
+ }
|
|
|
+
|
|
|
+ switch ids[0] {
|
|
|
+ case Version:
|
|
|
+ if len(ids) < 3 {
|
|
|
break
|
|
|
}
|
|
|
|
|
|
- if major, err := strconv.Atoi(versionIds[1]); err != nil || major != 1 {
|
|
|
+ if major, err := strconv.Atoi(ids[1]); err != nil || major != 1 {
|
|
|
break
|
|
|
}
|
|
|
|
|
@@ -131,53 +143,71 @@ func (s *SaslServer) handleRequest(conn net.Conn) {
|
|
|
|
|
|
fmt.Fprintf(conn, "%s\t%s\n", Cookie, hex.EncodeToString(cookieUuid[:]))
|
|
|
fmt.Fprintf(conn, "%s\n", Done)
|
|
|
- } else if strings.Index(currentMessage, Auth) == 0 {
|
|
|
- authIds := strings.Split(currentMessage, "\t")
|
|
|
- if len(authIds) < 2 {
|
|
|
- break
|
|
|
+ case Auth:
|
|
|
+ for _, authId := range ids {
|
|
|
+ if strings.Index(authId, "resp=") == 0 {
|
|
|
+ login, err := s.checkCredentials(authId[5:])
|
|
|
+ if err != nil {
|
|
|
+ fmt.Fprintf(conn, "%s\t%s\treason=%s\n", Fail, ids[1], err.Error())
|
|
|
+ } else {
|
|
|
+ fmt.Fprintf(conn, "%s\t%s\tuser=%s\n", Ok, ids[1], login)
|
|
|
+ }
|
|
|
+ continueState = ContinueStateNone
|
|
|
+ return
|
|
|
+ }
|
|
|
}
|
|
|
- fmt.Fprintf(conn, "%s\t%s\t%s\n", Cont, authIds[1], base64.StdEncoding.EncodeToString([]byte("Username:")))
|
|
|
+
|
|
|
+ fmt.Fprintf(conn, "%s\t%s\t%s\n", Cont, ids[1], base64.StdEncoding.EncodeToString([]byte("Username:")))
|
|
|
continueState = ContinueStateCredentials
|
|
|
- } else if strings.Index(currentMessage, Cont) == 0 {
|
|
|
- contIds := strings.Split(currentMessage, "\t")
|
|
|
- if len(contIds) < 2 {
|
|
|
+ case Cont:
|
|
|
+ if len(ids) < 2 {
|
|
|
break
|
|
|
}
|
|
|
|
|
|
if continueState == ContinueStateCredentials {
|
|
|
- if len(contIds) < 3 {
|
|
|
- fmt.Fprintf(conn, "%s\t%s\treason=%s\n", Fail, contIds[1], "invalid base64 data")
|
|
|
+ if len(ids) < 3 {
|
|
|
+ fmt.Fprintf(conn, "%s\t%s\treason=%s\n", Fail, ids[1], "invalid base64 data")
|
|
|
return
|
|
|
}
|
|
|
|
|
|
- credentials, err := base64.StdEncoding.DecodeString(contIds[2])
|
|
|
+ login, err := s.checkCredentials(ids[2])
|
|
|
if err != nil {
|
|
|
- fmt.Fprintf(conn, "%s\t%s\treason=%s\n", Fail, contIds[1], "invalid base64 data")
|
|
|
- return
|
|
|
- }
|
|
|
-
|
|
|
- credentialList := bytes.Split(credentials, []byte{0})
|
|
|
- if len(credentialList) < 3 {
|
|
|
- fmt.Fprintf(conn, "%s\t%s\treason=%s\n", Fail, contIds[1], "invalid user or password")
|
|
|
- return
|
|
|
+ fmt.Fprintf(conn, "%s\t%s\treason=%s\n", Fail, ids[1], err.Error())
|
|
|
+ } else {
|
|
|
+ fmt.Fprintf(conn, "%s\t%s\tuser=%s\n", Ok, ids[1], login)
|
|
|
}
|
|
|
-
|
|
|
- // identity := string(credentialList[0])
|
|
|
- login := string(credentialList[1])
|
|
|
- // password := string(credentialList[2])
|
|
|
- //TODO: Use auth here
|
|
|
- // if login != "semlanik@semlanik.org" || password != "test" {
|
|
|
- if true {
|
|
|
- fmt.Fprintf(conn, "%s\t%s\treason=%s\n", Fail, contIds[1], "invalid user or password")
|
|
|
- return
|
|
|
- }
|
|
|
-
|
|
|
- fmt.Fprintf(conn, "%s\t%s\tuser=%s\n", Ok, contIds[1], login)
|
|
|
continueState = ContinueStateNone
|
|
|
} else {
|
|
|
- fmt.Fprintf(conn, "%s\t%s\treason=%s\n", Fail, contIds[1], "invalid user or password")
|
|
|
+ fmt.Fprintf(conn, "%s\t%s\treason=%s\n", Fail, ids[1], "invalid user or password")
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
conn.Close()
|
|
|
}
|
|
|
+
|
|
|
+func (s *SaslServer) checkCredentials(credentialsBase64 string) (string, error) {
|
|
|
+ credentials, err := base64.StdEncoding.DecodeString(credentialsBase64)
|
|
|
+ if err != nil {
|
|
|
+ return "", errors.New("invalid base64 data")
|
|
|
+ }
|
|
|
+
|
|
|
+ credentialList := bytes.Split(credentials, []byte{0})
|
|
|
+ if len(credentialList) < 3 {
|
|
|
+ return "", errors.New("invalid user or password")
|
|
|
+ }
|
|
|
+
|
|
|
+ identity := string(credentialList[0])
|
|
|
+ login := string(credentialList[1])
|
|
|
+ password := string(credentialList[2])
|
|
|
+ if identity == "token" {
|
|
|
+ if s.authenticator.Verify(login, password) {
|
|
|
+ return login, nil
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ if _, ok := s.authenticator.Authenticate(login, password); ok {
|
|
|
+ return login, nil
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ return "", errors.New("invalid user or password")
|
|
|
+}
|