Add Identity Provider to support Authentication via Authentik,OpenID ,etc..

This commit is contained in:
Mo Tarbin 2025-02-04 23:59:49 -05:00
parent 0647725c68
commit 430f46ffee
10 changed files with 269 additions and 30 deletions

View file

@ -18,6 +18,7 @@ type Config struct {
SchedulerJobs SchedulerConfig `mapstructure:"scheduler_jobs" yaml:"scheduler_jobs"`
EmailConfig EmailConfig `mapstructure:"email" yaml:"email"`
StripeConfig StripeConfig `mapstructure:"stripe" yaml:"stripe"`
OAuth2Config OAuth2Config `mapstructure:"oauth2" yaml:"oauth2"`
IsDoneTickDotCom bool `mapstructure:"is_done_tick_dot_com" yaml:"is_done_tick_dot_com"`
IsUserCreationDisabled bool `mapstructure:"is_user_creation_disabled" yaml:"is_user_creation_disabled"`
}
@ -84,6 +85,16 @@ type EmailConfig struct {
AppHost string `mapstructure:"appHost"`
}
type OAuth2Config struct {
ClientID string `mapstructure:"client_id" yaml:"client_id"`
ClientSecret string `mapstructure:"client_secret" yaml:"client_secret"`
RedirectURL string `mapstructure:"redirect_url" yaml:"redirect_url"`
Scopes []string
AuthURL string `mapstructure:"auth_url" yaml:"auth_url"`
TokenURL string `mapstructure:"token_url" yaml:"token_url"`
UserInfoURL string `mapstructure:"user_info_url" yaml:"user_info_url"`
}
func NewConfig() *Config {
return &Config{
Telegram: TelegramConfig{
@ -126,9 +137,12 @@ func LoadConfig() *Config {
}
// get logger and log the current environment:
fmt.Printf("--ConfigLoad config for environment: %s ", os.Getenv("DT_ENV"))
viper.SetEnvPrefix("DT")
viper.AutomaticEnv()
viper.AddConfigPath("./config")
viper.SetConfigType("yaml")
err := viper.ReadInConfig()
// print a useful error:
if err != nil {
@ -141,9 +155,11 @@ func LoadConfig() *Config {
panic(err)
}
fmt.Printf("--ConfigLoad name : %s ", config.Name)
viper.SetEnvPrefix("DT")
viper.AutomaticEnv()
// bind all the environment variables to the config:
configEnvironmentOverrides(&config)
panic(config.OAuth2Config.ClientID)
return &config
// return LocalConfig()

View file

@ -23,3 +23,9 @@ DT_EMAIL_PORT=
DT_EMAIL_KEY=
DT_EMAIL_EMAIL=
DT_EMAIL_APP_HOST=
DT_OAUTH2_CLIENT_ID=
DT_OAUTH2_CLIENT_SECRET=
DT_OAUTH2_AUTH_URL=
DT_OAUTH2_TOKEN_URL=
DT_OAUTH2_USER_INFO_URL=
DT_OAUTH2_REDIRECT_URL=

View file

@ -35,3 +35,10 @@ email:
key:
email:
appHost:
oauth2:
client_id:
client_secret:
auth_url:
token_url:
user_info_url:
redirect_url:

3
go.mod
View file

@ -10,6 +10,8 @@ require (
github.com/gin-gonic/gin v1.10.0
github.com/glebarez/sqlite v1.11.0
github.com/go-telegram-bot-api/telegram-bot-api/v5 v5.5.1
github.com/golang-jwt/jwt/v5 v5.2.1
github.com/gregdel/pushover v1.3.1
github.com/rubenv/sql-migrate v1.7.0
github.com/spf13/viper v1.19.0
github.com/ulule/limiter/v3 v3.11.2
@ -50,7 +52,6 @@ require (
github.com/google/uuid v1.6.0 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect
github.com/googleapis/gax-go/v2 v2.12.4 // indirect
github.com/gregdel/pushover v1.3.1 // indirect
github.com/hashicorp/hcl v1.0.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect

2
go.sum
View file

@ -72,6 +72,8 @@ github.com/goccy/go-json v0.10.3 h1:KZ5WoDbxAIgm2HNbYckL0se1fHD6rz5j4ywS6ebzDqA=
github.com/goccy/go-json v0.10.3/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
github.com/golang-jwt/jwt/v4 v4.5.0 h1:7cYmW1XlMY7h7ii7UhUyChSgS5wUJEnm9uZVTGqOWzg=
github.com/golang-jwt/jwt/v4 v4.5.0/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk=
github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE=

View file

@ -0,0 +1,96 @@
package auth
import (
"context"
"encoding/json"
"errors"
"io"
"net/http"
"donetick.com/core/config"
"golang.org/x/oauth2"
)
type IdentityProviderUserInfo struct {
Identifier string
DisplayName string
Email string
}
type IdentityProvider struct {
config *config.OAuth2Config
isEnabled bool
}
func NewIdentityProvider(cfg *config.Config) *IdentityProvider {
if cfg.OAuth2Config.ClientID == "" || cfg.OAuth2Config.ClientSecret == "" {
return &IdentityProvider{isEnabled: false}
}
return &IdentityProvider{config: &cfg.OAuth2Config, isEnabled: true}
}
func (i *IdentityProvider) ExchangeToken(ctx context.Context, code string) (string, error) {
if !i.isEnabled {
return "", errors.New("identity provider is not enabled")
}
conf := &oauth2.Config{
ClientID: i.config.ClientID,
ClientSecret: i.config.ClientSecret,
RedirectURL: i.config.RedirectURL,
Scopes: i.config.Scopes,
Endpoint: oauth2.Endpoint{
AuthURL: i.config.AuthURL,
TokenURL: i.config.TokenURL,
},
}
token, err := conf.Exchange(ctx, code)
if err != nil {
return "", err
}
accessToken, ok := token.Extra("access_token").(string)
if !ok {
return "", errors.New("access token not found")
}
return accessToken, nil
}
func (i *IdentityProvider) GetUserInfo(ctx context.Context, accessToken string) (*IdentityProviderUserInfo, error) {
req, err := http.NewRequest("GET", i.config.UserInfoURL, nil)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("Accept", "application/json")
resp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, err
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
defer resp.Body.Close()
var claims map[string]any
err = json.Unmarshal(body, &claims)
if err != nil {
return nil, errors.New("failed to unmarshal claims")
}
userInfo := IdentityProviderUserInfo{}
if val, ok := claims["sub"]; ok {
userInfo.Identifier = val.(string)
}
if val, ok := claims["name"]; ok {
userInfo.DisplayName = val.(string)
}
if val, ok := claims["email"]; ok {
userInfo.Email = val.(string)
}
return &userInfo, nil
}

View file

@ -16,7 +16,7 @@ import (
"github.com/gin-gonic/gin"
)
type Webhook struct {
type API struct {
choreRepo *chRepo.ChoreRepository
circleRepo *cRepo.CircleRepository
thingRepo *tRepo.ThingRepository
@ -24,9 +24,9 @@ type Webhook struct {
tRepo *tRepo.ThingRepository
}
func NewWebhook(cr *chRepo.ChoreRepository, circleRepo *cRepo.CircleRepository,
thingRepo *tRepo.ThingRepository, userRepo *uRepo.UserRepository, tRepo *tRepo.ThingRepository) *Webhook {
return &Webhook{
func NewAPI(cr *chRepo.ChoreRepository, circleRepo *cRepo.CircleRepository,
thingRepo *tRepo.ThingRepository, userRepo *uRepo.UserRepository, tRepo *tRepo.ThingRepository) *API {
return &API{
choreRepo: cr,
circleRepo: circleRepo,
thingRepo: thingRepo,
@ -35,7 +35,7 @@ func NewWebhook(cr *chRepo.ChoreRepository, circleRepo *cRepo.CircleRepository,
}
}
func (h *Webhook) UpdateThingState(c *gin.Context) {
func (h *API) UpdateThingState(c *gin.Context) {
thing, shouldReturn := validateUserAndThing(c, h)
if shouldReturn {
return
@ -60,7 +60,7 @@ func (h *Webhook) UpdateThingState(c *gin.Context) {
c.JSON(200, gin.H{})
}
func (h *Webhook) ChangeThingState(c *gin.Context) {
func (h *API) ChangeThingState(c *gin.Context) {
thing, shouldReturn := validateUserAndThing(c, h)
if shouldReturn {
return
@ -109,7 +109,7 @@ func (h *Webhook) ChangeThingState(c *gin.Context) {
c.JSON(200, gin.H{"state": thing.State})
}
func WebhookEvaluateTriggerAndScheduleDueDate(h *Webhook, c *gin.Context, thing *tModel.Thing) bool {
func WebhookEvaluateTriggerAndScheduleDueDate(h *API, c *gin.Context, thing *tModel.Thing) bool {
// handler should be interface to not duplicate both WebhookEvaluateTriggerAndScheduleDueDate and EvaluateTriggerAndScheduleDueDate
// this is bad code written Saturday at 2:25 AM
@ -134,7 +134,7 @@ func WebhookEvaluateTriggerAndScheduleDueDate(h *Webhook, c *gin.Context, thing
return false
}
func validateUserAndThing(c *gin.Context, h *Webhook) (*tModel.Thing, bool) {
func validateUserAndThing(c *gin.Context, h *API) (*tModel.Thing, bool) {
apiToken := c.GetHeader("secretkey")
if apiToken == "" {
c.JSON(401, gin.H{"error": "Unauthorized"})
@ -162,7 +162,7 @@ func validateUserAndThing(c *gin.Context, h *Webhook) (*tModel.Thing, bool) {
return thing, false
}
func Webhooks(cfg *config.Config, w *Webhook, r *gin.Engine, auth *jwt.GinJWTMiddleware) {
func APIs(cfg *config.Config, w *API, r *gin.Engine, auth *jwt.GinJWTMiddleware) {
thingsAPI := r.Group("eapi/v1/things")

View file

@ -31,16 +31,18 @@ type Handler struct {
circleRepo *cRepo.CircleRepository
jwtAuth *jwt.GinJWTMiddleware
email *email.EmailSender
identityProvider *auth.IdentityProvider
isDonetickDotCom bool
IsUserCreationDisabled bool
}
func NewHandler(ur *uRepo.UserRepository, cr *cRepo.CircleRepository, jwtAuth *jwt.GinJWTMiddleware, email *email.EmailSender, config *config.Config) *Handler {
func NewHandler(ur *uRepo.UserRepository, cr *cRepo.CircleRepository, jwtAuth *jwt.GinJWTMiddleware, email *email.EmailSender, idp *auth.IdentityProvider, config *config.Config) *Handler {
return &Handler{
userRepo: ur,
circleRepo: cr,
jwtAuth: jwtAuth,
email: email,
identityProvider: idp,
isDonetickDotCom: config.IsDoneTickDotCom,
IsUserCreationDisabled: config.IsUserCreationDisabled,
}
@ -178,7 +180,8 @@ func (h *Handler) thirdPartyAuthCallback(c *gin.Context) {
provider := c.Param("provider")
logger.Infow("account.handler.thirdPartyAuthCallback", "provider", provider)
if provider == "google" {
switch provider {
case "google":
c.Set("auth_provider", "3rdPartyAuth")
type OAuthRequest struct {
Token string `json:"token" binding:"required"`
@ -219,7 +222,7 @@ func (h *Handler) thirdPartyAuthCallback(c *gin.Context) {
Image: userinfo.Picture,
Password: encodedPassword,
DisplayName: userinfo.GivenName,
Provider: 2,
Provider: uModel.AuthProviderGoogle,
}
createdUser, err := h.userRepo.CreateUser(c, acc)
if err != nil {
@ -278,6 +281,106 @@ func (h *Handler) thirdPartyAuthCallback(c *gin.Context) {
}
c.JSON(http.StatusOK, gin.H{"token": tokenString, "expire": expire})
return
case "oauth2":
c.Set("auth_provider", "3rdPartyAuth")
// Read the ID token from the request bod
type Request struct {
Code string `json:"code"`
}
var req Request
if err := c.BindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request"})
return
}
token, err := h.identityProvider.ExchangeToken(c, req.Code)
if err != nil {
logger.Errorw("account.handler.thirdPartyAuthCallback (oauth2) failed to exchange token", "err", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to exchange token"})
return
}
claims, err := h.identityProvider.GetUserInfo(c, token)
if err != nil {
logger.Errorw("account.handler.thirdPartyAuthCallback (oauth2) failed to get claims", "err", err)
}
acc, err := h.userRepo.FindByEmail(c, claims.Email)
if err != nil {
// Create user
password := auth.GenerateRandomPassword(12)
encodedPassword, err := auth.EncodePassword(password)
if err != nil {
logger.Errorw("account.handler.thirdPartyAuthCallback (oauth2) password encoding failed", "err", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "Password encoding failed"})
return
}
acc = &uModel.User{
Username: claims.Email,
Email: claims.Email,
Password: encodedPassword,
DisplayName: claims.DisplayName,
Provider: uModel.AuthProviderOAuth2,
}
createdUser, err := h.userRepo.CreateUser(c, acc)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": "Unable to create user",
})
return
}
// Create Circle for the user:
userCircle, err := h.circleRepo.CreateCircle(c, &cModel.Circle{
Name: claims.DisplayName + "'s circle",
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
InviteCode: utils.GenerateInviteCode(c),
})
if err != nil {
c.JSON(500, gin.H{
"error": "Error creating circle",
})
return
}
if err := h.circleRepo.AddUserToCircle(c, &cModel.UserCircle{
UserID: createdUser.ID,
CircleID: userCircle.ID,
Role: "admin",
IsActive: true,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}); err != nil {
c.JSON(500, gin.H{
"error": "Error adding user to circle",
})
return
}
createdUser.CircleID = userCircle.ID
if err := h.userRepo.UpdateUser(c, createdUser); err != nil {
c.JSON(500, gin.H{
"error": "Error updating user",
})
return
}
}
// ... (JWT generation and response)
c.Set("user_account", acc)
h.jwtAuth.Authenticator(c)
tokenString, expire, err := h.jwtAuth.TokenGenerator(acc)
if err != nil {
logger.Errorw("Unable to Generate a Token")
c.JSON(http.StatusInternalServerError, gin.H{
"error": "Unable to Generate a Token",
})
return
}
c.JSON(http.StatusOK, gin.H{"token": tokenString, "expire": expire})
return
default:
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid provider"})
return
}
}

View file

@ -11,7 +11,7 @@ type User struct {
DisplayName string `json:"displayName" gorm:"column:display_name"` // Display name
Username string `json:"username" gorm:"column:username;unique"` // Username (unique)
Email string `json:"email" gorm:"column:email;unique"` // Email (unique)
Provider int `json:"provider" gorm:"column:provider"` // Provider
Provider AuthProvider `json:"provider" gorm:"column:provider"` // Provider
Password string `json:"-" gorm:"column:password"` // Password
CircleID int `json:"circleID" gorm:"column:circle_id"` // Circle ID
ChatID int64 `json:"chatID" gorm:"column:chat_id"` // Telegram chat ID
@ -48,3 +48,10 @@ type UserNotificationTarget struct {
TargetID string `json:"target_id" gorm:"column:target_id"` // Target ID
CreatedAt time.Time `json:"-" gorm:"column:created_at"`
}
type AuthProvider int
const (
AuthProviderDonetick AuthProvider = iota
AuthProviderOAuth2
AuthProviderGoogle
)

View file

@ -52,6 +52,7 @@ func main() {
// fx.Provide(config.NewConfig),
fx.Provide(auth.NewAuthMiddleware),
fx.Provide(auth.NewIdentityProvider),
// fx.Provide(NewBot),
fx.Provide(database.NewDatabase),
@ -89,7 +90,7 @@ func main() {
fx.Provide(lRepo.NewLabelRepository),
fx.Provide(label.NewHandler),
fx.Provide(thing.NewWebhook),
fx.Provide(thing.NewAPI),
fx.Provide(thing.NewHandler),
fx.Provide(chore.NewAPI),
@ -103,7 +104,7 @@ func main() {
user.Routes,
circle.Routes,
thing.Routes,
thing.Webhooks,
thing.APIs,
label.Routes,
frontend.Routes,