218 Zeilen
5.6 KiB
Go
218 Zeilen
5.6 KiB
Go
package saml
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rsa"
|
|
"crypto/x509"
|
|
"errors"
|
|
"fmt"
|
|
"net/http"
|
|
"net/url"
|
|
"path"
|
|
"time"
|
|
|
|
"github.com/crewjam/saml"
|
|
"github.com/crewjam/saml/samlsp"
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/golang-jwt/jwt/v4"
|
|
"github.com/rs/zerolog/log"
|
|
"go.sebtobie.de/httpserver"
|
|
"go.sebtobie.de/httpserver/auth"
|
|
"go.sebtobie.de/httpserver/constants"
|
|
)
|
|
|
|
// musturi unwraps the (URL, error) pair returned by url.Parse-style calls.
// It returns u when err is nil and panics with err otherwise, so it is only
// suitable for package-level initialisers such as defaultsaml.
func musturi(u *url.URL, err error) *url.URL {
	if err == nil {
		return u
	}
	panic(err)
}
|
|
|
|
var (
|
|
defaultsaml = &SAML{
|
|
idp: musturi(url.ParseRequestURI("https://samltest.id/saml/idp")),
|
|
Domain: "example.com",
|
|
Cookiename: "ILOVECOOKIES",
|
|
}
|
|
_ httpserver.Site = defaultsaml
|
|
_ httpserver.ConfigSite = defaultsaml
|
|
)
|
|
|
|
// metadata is an empty struct type.
// NOTE(review): it is not referenced anywhere in this file — confirm it is
// still used elsewhere in the package before keeping it.
type metadata struct{}
|
|
|
|
// SAML is an Applicance to react on Events from the SAML-IDP and that provides an interface to get data from the IDP in a standartised fashion.
|
|
type SAML struct {
|
|
router *gin.RouterGroup
|
|
publicroot string
|
|
SPPublicKey string
|
|
sppublickey *x509.Certificate
|
|
SPPrivatekey string
|
|
spprivatekey *rsa.PrivateKey
|
|
JWTPrivatekey string
|
|
jwtprivatekey *rsa.PrivateKey
|
|
idp *url.URL
|
|
sp *saml.ServiceProvider
|
|
HTTPClient http.Client `toml:"-"`
|
|
Domain string `toml:"domain"`
|
|
Cookiename string
|
|
}
|
|
|
|
// Setup sets the saml object up.
|
|
func (s *SAML) Setup(config httpserver.SiteConfig) (err error) {
|
|
log.Info().Msg("Setting up SAML service provider")
|
|
s.Domain = config["domain"].(string)
|
|
s.sp = &saml.ServiceProvider{
|
|
AcsURL: url.URL{
|
|
Scheme: "https",
|
|
Host: s.Domain,
|
|
Path: path.Join(s.publicroot, "acs"),
|
|
},
|
|
MetadataURL: url.URL{
|
|
Scheme: "https",
|
|
Host: s.Domain,
|
|
Path: path.Join(s.publicroot, "metadata.xml"),
|
|
},
|
|
}
|
|
switch config["metadatavalid"].(type) {
|
|
case time.Duration:
|
|
s.sp.MetadataValidDuration = config["metadatavalid"].(time.Duration)
|
|
case int:
|
|
s.sp.MetadataValidDuration = time.Duration(config["metadatavalid"].(int))
|
|
case int8:
|
|
s.sp.MetadataValidDuration = time.Duration(config["metadatavalid"].(int8))
|
|
case int16:
|
|
s.sp.MetadataValidDuration = time.Duration(config["metadatavalid"].(int16))
|
|
case int32:
|
|
s.sp.MetadataValidDuration = time.Duration(config["metadatavalid"].(int32))
|
|
case int64:
|
|
s.sp.MetadataValidDuration = time.Duration(config["metadatavalid"].(int64))
|
|
}
|
|
var key any
|
|
if keyfile, found := config["spprivatekey"]; found {
|
|
s.SPPrivatekey = keyfile.(string)
|
|
key, err = initcert(s.SPPrivatekey, func(key any) bool {
|
|
_, ok := key.(*rsa.PrivateKey)
|
|
return ok
|
|
})
|
|
if err != nil {
|
|
return
|
|
}
|
|
s.spprivatekey = key.(*rsa.PrivateKey)
|
|
} else {
|
|
return errors.New("SP Privatekey not found")
|
|
}
|
|
if keyfile, found := config["sppublickey"]; found {
|
|
s.SPPublicKey = keyfile.(string)
|
|
key, err = initcert(s.SPPublicKey, func(key any) bool {
|
|
_, ok := key.(*x509.Certificate)
|
|
return ok
|
|
})
|
|
if err != nil {
|
|
return
|
|
}
|
|
s.sppublickey = key.(*x509.Certificate)
|
|
} else {
|
|
return errors.New("SP Publickey not found")
|
|
}
|
|
|
|
if keyfile, found := config["jwtprivatekey"]; found {
|
|
s.JWTPrivatekey = keyfile.(string)
|
|
key, err = initcert(s.JWTPrivatekey, func(key any) bool {
|
|
_, ok := key.(*rsa.PrivateKey)
|
|
return ok
|
|
})
|
|
if err != nil {
|
|
return
|
|
}
|
|
s.jwtprivatekey = key.(*rsa.PrivateKey)
|
|
} else {
|
|
return errors.New("JWT Privatekey not found")
|
|
}
|
|
s.sp = &saml.ServiceProvider{
|
|
Key: s.spprivatekey,
|
|
Certificate: s.sppublickey,
|
|
AuthnNameIDFormat: saml.PersistentNameIDFormat,
|
|
}
|
|
if idp, found := config["idp"]; found {
|
|
s.idp, err = url.ParseRequestURI(idp.(string))
|
|
if err != nil {
|
|
return
|
|
}
|
|
s.sp.IDPMetadata, err = samlsp.FetchMetadata(context.Background(), &s.HTTPClient, *s.idp)
|
|
if err != nil {
|
|
return
|
|
}
|
|
} else {
|
|
err = errors.New("IDP in configfile not found")
|
|
}
|
|
return
|
|
}
|
|
|
|
// Init initalizes the routes
|
|
func (s *SAML) Init(router *gin.RouterGroup) {
|
|
s.publicroot = router.BasePath()
|
|
s.router = router
|
|
router.GET("metadata.xml", s.metadataHF)
|
|
router.POST("acs", s.acsHF)
|
|
}
|
|
|
|
// metadataHF serves the service-provider metadata as XML. It answers 500 when
// Setup has not built the service provider or no metadata is produced.
func (s *SAML) metadataHF(c *gin.Context) {
	if s.sp == nil {
		c.AbortWithStatus(http.StatusInternalServerError)
		return
	}
	if md := s.sp.Metadata(); md != nil {
		c.XML(http.StatusOK, md)
		return
	}
	c.AbortWithStatus(http.StatusInternalServerError)
}
|
|
|
|
func (s *SAML) acsHF(c *gin.Context) {
|
|
account := c.MustGet("account").(auth.Account)
|
|
err := c.Request.ParseForm()
|
|
if err != nil {
|
|
c.AbortWithError(http.StatusNotAcceptable, err)
|
|
}
|
|
var assert *saml.Assertion
|
|
assert, err = s.sp.ParseResponse(c.Request, []string{account.Get("jti").(string)})
|
|
if err != nil {
|
|
realerr, _ := err.(*saml.InvalidResponseError)
|
|
err = realerr.PrivateErr
|
|
log.Error().AnErr("Assertionerror", err).Msgf("Assertion Error")
|
|
fmt.Print(realerr.Response)
|
|
c.AbortWithStatus(http.StatusBadRequest)
|
|
return
|
|
}
|
|
data := attributeStatementstomap(assert.AttributeStatements)
|
|
token, err := jwttoken(jwt.MapClaims{
|
|
string(constants.AccountAnon): false,
|
|
string(constants.AccountID): account.Get(constants.AccountID).(string),
|
|
string(constants.AccountUser): data["uid"][0],
|
|
}, s.jwtprivatekey)
|
|
if err != nil {
|
|
c.AbortWithStatus(http.StatusInternalServerError)
|
|
return
|
|
}
|
|
c.SetCookie(s.Cookiename, token, int(time.Hour*24*30), "", "", true, true)
|
|
redirect, found := c.GetPostForm("RelayState")
|
|
if !found {
|
|
c.AbortWithStatus(http.StatusNotAcceptable)
|
|
return
|
|
}
|
|
c.Redirect(http.StatusSeeOther, redirect)
|
|
}
|
|
|
|
// Defaults returns the default values for the config
|
|
func (s *SAML) Defaults() httpserver.SiteConfig {
|
|
return map[string]any{
|
|
"domain": "example.com",
|
|
"idp": defaultsaml.idp.String(),
|
|
"sppublickey": "publickey.pem",
|
|
"spprivatekey": "privatekey.pem",
|
|
"jwtprivatekey": "privatekey.pem",
|
|
"metadatavalid": time.Duration(time.Hour * 24),
|
|
}
|
|
}
|