1
0
Fork 0
httpserver/modules/saml/saml.go

217 Zeilen
5.6 KiB
Go

package saml
import (
"context"
"crypto/rsa"
"crypto/x509"
"errors"
"fmt"
"net/http"
"net/url"
"path"
"time"
"github.com/crewjam/saml"
"github.com/crewjam/saml/samlsp"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v4"
"github.com/phuslu/log"
"go.sebtobie.de/httpserver"
"go.sebtobie.de/httpserver/auth"
)
// musturi unwraps the (URL, error) pair produced by a URL parser.
// It returns the URL unchanged when parseErr is nil and panics with parseErr
// otherwise, which makes it usable in package-level initialisers.
func musturi(parsed *url.URL, parseErr error) *url.URL {
	if parseErr == nil {
		return parsed
	}
	panic(parseErr)
}
// defaultsaml is the package's default SAML instance. Its idp URL is the one
// reported by Defaults; the initialiser panics (via musturi) if that URL
// literal ever becomes unparsable.
var defaultsaml = &SAML{
	Cookiename: "ILOVECOOKIES",
	Domain:     "example.com",
	idp:        musturi(url.ParseRequestURI("https://samltest.id/saml/idp")),
}

// Compile-time checks that *SAML satisfies the httpserver interfaces.
var (
	_ httpserver.ConfigSite = defaultsaml
	_ httpserver.Site       = defaultsaml
)
// metadata is an empty struct type. Nothing in the visible part of this file
// references it.
// NOTE(review): confirm whether it is still needed.
type metadata struct{}
// SAML is an appliance that reacts to events from the SAML IDP and provides
// an interface to get data from the IDP in a standardised fashion.
type SAML struct {
// router is the gin router group the handlers are registered on; set by Init.
router *gin.RouterGroup
// publicroot is the base path of router; Setup uses it to build the paths
// of the ACS and metadata URLs.
publicroot string
// SPPublicKey is the "sppublickey" config value handed to initcert
// (presumably a file path; the default is "publickey.pem").
SPPublicKey string
// sppublickey is the certificate loaded from SPPublicKey.
sppublickey *x509.Certificate
// SPPrivatekey is the "spprivatekey" config value handed to initcert
// (presumably a file path; the default is "privatekey.pem").
SPPrivatekey string
// spprivatekey is the RSA key loaded from SPPrivatekey.
spprivatekey *rsa.PrivateKey
// JWTPrivatekey is the "jwtprivatekey" config value handed to initcert
// (presumably a file path; the default is "privatekey.pem").
JWTPrivatekey string
// jwtprivatekey is the RSA key passed to jwttoken when the ACS handler
// issues the session token.
jwtprivatekey *rsa.PrivateKey
// idp is the URL the IDP metadata is fetched from during Setup.
idp *url.URL
// sp is the service provider created by Setup; nil until Setup has run.
sp *saml.ServiceProvider
// HTTPClient is the client used to fetch the IDP metadata; it is excluded
// from the TOML configuration.
HTTPClient http.Client `toml:"-"`
// Domain is the host used in the ACS and metadata URLs.
Domain string `toml:"domain"`
// Cookiename is the name of the cookie the ACS handler stores the token in.
Cookiename string
}
// Setup configures the SAML service provider from config.
//
// The following keys are read:
//   - "domain" (string, required): host used for the ACS and metadata URLs.
//   - "metadatavalid" (time.Duration or a signed integer type, optional):
//     validity of the generated metadata. Integers are converted with
//     time.Duration(v), i.e. they are taken as nanoseconds.
//   - "spprivatekey", "sppublickey", "jwtprivatekey" (string, required):
//     handed to initcert to load the RSA keys and the X.509 certificate.
//   - "idp" (string, required): URL the IDP metadata is fetched from.
//
// An error is returned when a required key is missing or is not a string,
// when initcert fails or yields an unexpected type, when the IDP URL cannot
// be parsed, or when the IDP metadata cannot be fetched.
func (s *SAML) Setup(config httpserver.SiteConfig) (err error) {
	log.Info().Msg("Setting up SAML service provider")

	// str reads a required string value; missing is the error text used when
	// the key is absent. A wrong type yields an error instead of a panic.
	str := func(key, missing string) (string, error) {
		raw, found := config[key]
		if !found {
			return "", errors.New(missing)
		}
		val, ok := raw.(string)
		if !ok {
			return "", fmt.Errorf("config value %q must be a string, got %T", key, raw)
		}
		return val, nil
	}

	domain, err := str("domain", "domain in configfile not found")
	if err != nil {
		return err
	}
	s.Domain = domain

	// The service provider is created exactly once; later steps only fill in
	// further fields so the URLs and the validity set here are not lost.
	s.sp = &saml.ServiceProvider{
		AcsURL: url.URL{
			Scheme: "https",
			Host:   s.Domain,
			Path:   path.Join(s.publicroot, "acs"),
		},
		MetadataURL: url.URL{
			Scheme: "https",
			Host:   s.Domain,
			Path:   path.Join(s.publicroot, "metadata.xml"),
		},
		AuthnNameIDFormat: saml.PersistentNameIDFormat,
	}

	switch valid := config["metadatavalid"].(type) {
	case time.Duration:
		s.sp.MetadataValidDuration = valid
	case int:
		s.sp.MetadataValidDuration = time.Duration(valid)
	case int8:
		s.sp.MetadataValidDuration = time.Duration(valid)
	case int16:
		s.sp.MetadataValidDuration = time.Duration(valid)
	case int32:
		s.sp.MetadataValidDuration = time.Duration(valid)
	case int64:
		s.sp.MetadataValidDuration = time.Duration(valid)
	}

	isRSAKey := func(key interface{}) bool {
		_, ok := key.(*rsa.PrivateKey)
		return ok
	}
	isCertificate := func(key interface{}) bool {
		_, ok := key.(*x509.Certificate)
		return ok
	}

	// load stores the configured location in *dst and lets initcert load it.
	load := func(name, missing string, dst *string, accept func(interface{}) bool) (interface{}, error) {
		location, err := str(name, missing)
		if err != nil {
			return nil, err
		}
		*dst = location
		var loaded interface{}
		loaded, err = initcert(location, accept)
		if err != nil {
			return nil, fmt.Errorf("loading %s %q: %w", name, location, err)
		}
		return loaded, nil
	}

	var (
		loaded interface{}
		ok     bool
	)

	if loaded, err = load("spprivatekey", "SP Privatekey not found", &s.SPPrivatekey, isRSAKey); err != nil {
		return err
	}
	if s.spprivatekey, ok = loaded.(*rsa.PrivateKey); !ok {
		return fmt.Errorf("spprivatekey: expected *rsa.PrivateKey, got %T", loaded)
	}

	if loaded, err = load("sppublickey", "SP Publickey not found", &s.SPPublicKey, isCertificate); err != nil {
		return err
	}
	if s.sppublickey, ok = loaded.(*x509.Certificate); !ok {
		return fmt.Errorf("sppublickey: expected *x509.Certificate, got %T", loaded)
	}

	if loaded, err = load("jwtprivatekey", "JWT Privatekey not found", &s.JWTPrivatekey, isRSAKey); err != nil {
		return err
	}
	if s.jwtprivatekey, ok = loaded.(*rsa.PrivateKey); !ok {
		return fmt.Errorf("jwtprivatekey: expected *rsa.PrivateKey, got %T", loaded)
	}

	s.sp.Key = s.spprivatekey
	s.sp.Certificate = s.sppublickey

	idp, err := str("idp", "IDP in configfile not found")
	if err != nil {
		return err
	}
	// Parse into a local first so a bad value does not replace s.idp with nil.
	idpURL, err := url.ParseRequestURI(idp)
	if err != nil {
		return fmt.Errorf("parsing IDP URL %q: %w", idp, err)
	}
	s.idp = idpURL

	idpMetadata, err := samlsp.FetchMetadata(context.Background(), &s.HTTPClient, *s.idp)
	if err != nil {
		return fmt.Errorf("fetching IDP metadata from %q: %w", idp, err)
	}
	s.sp.IDPMetadata = idpMetadata
	return nil
}
// Init remembers the router group and its base path and registers the
// metadata (GET "metadata.xml") and assertion consumer (POST "acs") handlers.
func (s *SAML) Init(router *gin.RouterGroup) {
	s.router = router
	s.publicroot = s.router.BasePath()

	s.router.GET("metadata.xml", s.metadataHF)
	s.router.POST("acs", s.acsHF)
}
// metadataHF serves the service provider metadata as XML. It answers with
// status 500 when the service provider has not been set up or yields no
// metadata.
func (s *SAML) metadataHF(c *gin.Context) {
	if s.sp != nil {
		if descriptor := s.sp.Metadata(); descriptor != nil {
			c.XML(http.StatusOK, descriptor)
			return
		}
	}
	c.AbortWithStatus(http.StatusInternalServerError)
}
// acsHF is the assertion consumer handler. It validates the posted SAML
// response against the request ID stored under "jti" in the current account,
// issues a JWT carrying the account ID and the first "uid" attribute of the
// assertion, stores it in the cookie named s.Cookiename and redirects to the
// posted RelayState.
//
// Responses:
//   - 500 when the service provider is not set up, the "account" context
//     value or its "jti"/ID values are missing or of the wrong type, or the
//     token cannot be created.
//   - 406 when the form cannot be parsed or RelayState is missing.
//   - 400 when the SAML response is rejected or carries no "uid" attribute.
//   - 303 on success.
func (s *SAML) acsHF(c *gin.Context) {
	// Same guard as metadataHF: without Setup there is nothing to validate with.
	if s.sp == nil {
		c.AbortWithStatus(http.StatusInternalServerError)
		return
	}

	raw, found := c.Get("account")
	account, ok := raw.(auth.Account)
	if !found || !ok {
		log.Error().Msg("no account in request context")
		c.AbortWithStatus(http.StatusInternalServerError)
		return
	}
	requestID, ok := account.Get("jti").(string)
	if !ok {
		log.Error().Msg("account carries no request id")
		c.AbortWithStatus(http.StatusInternalServerError)
		return
	}
	accountID, ok := account.Get(auth.AccountID).(string)
	if !ok {
		log.Error().Msg("account carries no account id")
		c.AbortWithStatus(http.StatusInternalServerError)
		return
	}

	if err := c.Request.ParseForm(); err != nil {
		c.AbortWithError(http.StatusNotAcceptable, err)
		// Stop here; continuing would work on an unparsed form.
		return
	}

	assertion, err := s.sp.ParseResponse(c.Request, []string{requestID})
	if err != nil {
		// Only *saml.InvalidResponseError carries a private cause and the raw
		// response; any other error is logged as is.
		var invalid *saml.InvalidResponseError
		if errors.As(err, &invalid) {
			log.Error().AnErr("Assertionerror", invalid.PrivateErr).Msgf("Assertion Error")
			// Dumps the rejected response on stdout for debugging.
			fmt.Print(invalid.Response)
		} else {
			log.Error().AnErr("Assertionerror", err).Msgf("Assertion Error")
		}
		c.AbortWithStatus(http.StatusBadRequest)
		return
	}

	data := attributeStatementstomap(assertion.AttributeStatements)
	uids := data["uid"]
	if len(uids) == 0 {
		log.Error().Msg("assertion carries no uid attribute")
		c.AbortWithStatus(http.StatusBadRequest)
		return
	}

	token, err := jwttoken(jwt.MapClaims{
		string(auth.AccountAnon): false,
		string(auth.AccountID):   accountID,
		string(auth.AccountUser): uids[0],
	}, s.jwtprivatekey)
	if err != nil {
		log.Error().AnErr("Tokenerror", err).Msg("creating session token")
		c.AbortWithStatus(http.StatusInternalServerError)
		return
	}

	// The cookie max-age is expressed in seconds, not in time.Duration units.
	const cookieLifetime = 30 * 24 * time.Hour
	c.SetCookie(s.Cookiename, token, int(cookieLifetime/time.Second), "", "", true, true)

	redirect, found := c.GetPostForm("RelayState")
	if !found {
		c.AbortWithStatus(http.StatusNotAcceptable)
		return
	}
	// NOTE(review): RelayState comes from the request body and is used as the
	// redirect target without validation; consider restricting it to known
	// locations to avoid an open redirect.
	c.Redirect(http.StatusSeeOther, redirect)
}
// Defaults returns the default configuration values: the example domain, the
// IDP URL of defaultsaml, the default key/certificate locations and a
// metadata validity of 24 hours.
func (s *SAML) Defaults() httpserver.SiteConfig {
	defaults := make(map[string]interface{}, 6)

	defaults["domain"] = "example.com"
	defaults["idp"] = defaultsaml.idp.String()

	defaults["sppublickey"] = "publickey.pem"
	defaults["spprivatekey"] = "privatekey.pem"
	defaults["jwtprivatekey"] = "privatekey.pem"

	defaults["metadatavalid"] = 24 * time.Hour

	return defaults
}