From 430f46ffee034f0ecb8b22deb1cf846a0bfb4cab Mon Sep 17 00:00:00 2001 From: Mo Tarbin Date: Tue, 4 Feb 2025 23:59:49 -0500 Subject: [PATCH] Add Identity Provider to support Authentication via Authentik,OpenID ,etc.. --- config/config.go | 20 +++- config/selfhosted.env | 8 +- config/selfhosted.yaml | 7 ++ go.mod | 3 +- go.sum | 2 + internal/authorization/identity_provider.go | 96 +++++++++++++++++ internal/thing/{webhook.go => api.go} | 18 ++-- internal/user/handler.go | 109 +++++++++++++++++++- internal/user/model/model.go | 31 +++--- main.go | 5 +- 10 files changed, 269 insertions(+), 30 deletions(-) create mode 100644 internal/authorization/identity_provider.go rename internal/thing/{webhook.go => api.go} (87%) diff --git a/config/config.go b/config/config.go index eccf204..c91664c 100644 --- a/config/config.go +++ b/config/config.go @@ -18,6 +18,7 @@ type Config struct { SchedulerJobs SchedulerConfig `mapstructure:"scheduler_jobs" yaml:"scheduler_jobs"` EmailConfig EmailConfig `mapstructure:"email" yaml:"email"` StripeConfig StripeConfig `mapstructure:"stripe" yaml:"stripe"` + OAuth2Config OAuth2Config `mapstructure:"oauth2" yaml:"oauth2"` IsDoneTickDotCom bool `mapstructure:"is_done_tick_dot_com" yaml:"is_done_tick_dot_com"` IsUserCreationDisabled bool `mapstructure:"is_user_creation_disabled" yaml:"is_user_creation_disabled"` } @@ -84,6 +85,16 @@ type EmailConfig struct { AppHost string `mapstructure:"appHost"` } +type OAuth2Config struct { + ClientID string `mapstructure:"client_id" yaml:"client_id"` + ClientSecret string `mapstructure:"client_secret" yaml:"client_secret"` + RedirectURL string `mapstructure:"redirect_url" yaml:"redirect_url"` + Scopes []string + AuthURL string `mapstructure:"auth_url" yaml:"auth_url"` + TokenURL string `mapstructure:"token_url" yaml:"token_url"` + UserInfoURL string `mapstructure:"user_info_url" yaml:"user_info_url"` +} + func NewConfig() *Config { return &Config{ Telegram: TelegramConfig{ @@ -126,9 +137,12 @@ func LoadConfig() *Config { } // get logger and log the current environment: fmt.Printf("--ConfigLoad config for environment: %s ", os.Getenv("DT_ENV")) + viper.SetEnvPrefix("DT") + viper.AutomaticEnv() viper.AddConfigPath("./config") viper.SetConfigType("yaml") + err := viper.ReadInConfig() // print a useful error: if err != nil { @@ -141,9 +155,11 @@ func LoadConfig() *Config { panic(err) } fmt.Printf("--ConfigLoad name : %s ", config.Name) - viper.SetEnvPrefix("DT") - viper.AutomaticEnv() + + // bind all the environment variables to the config: + configEnvironmentOverrides(&config) + panic(config.OAuth2Config.ClientID) return &config // return LocalConfig() diff --git a/config/selfhosted.env b/config/selfhosted.env index a03df9e..e5b6540 100644 --- a/config/selfhosted.env +++ b/config/selfhosted.env @@ -22,4 +22,10 @@ DT_EMAIL_HOST= DT_EMAIL_PORT= DT_EMAIL_KEY= DT_EMAIL_EMAIL= -DT_EMAIL_APP_HOST= \ No newline at end of file +DT_EMAIL_APP_HOST= +DT_OAUTH2_CLIENT_ID= +DT_OAUTH2_CLIENT_SECRET= +DT_OAUTH2_AUTH_URL= +DT_OAUTH2_TOKEN_URL= +DT_OAUTH2_USER_INFO_URL= +DT_OAUTH2_REDIRECT_URL= \ No newline at end of file diff --git a/config/selfhosted.yaml b/config/selfhosted.yaml index ca6c4af..e8641d1 100644 --- a/config/selfhosted.yaml +++ b/config/selfhosted.yaml @@ -35,3 +35,10 @@ email: key: email: appHost: +oauth2: + client_id: + client_secret: + auth_url: + token_url: + user_info_url: + redirect_url: diff --git a/go.mod b/go.mod index ce42755..654b0c5 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,8 @@ require ( github.com/gin-gonic/gin v1.10.0 github.com/glebarez/sqlite v1.11.0 github.com/go-telegram-bot-api/telegram-bot-api/v5 v5.5.1 + github.com/golang-jwt/jwt/v5 v5.2.1 + github.com/gregdel/pushover v1.3.1 github.com/rubenv/sql-migrate v1.7.0 github.com/spf13/viper v1.19.0 github.com/ulule/limiter/v3 v3.11.2 @@ -50,7 +52,6 @@ require ( github.com/google/uuid v1.6.0 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect github.com/googleapis/gax-go/v2 v2.12.4 // indirect - github.com/gregdel/pushover v1.3.1 // indirect github.com/hashicorp/hcl v1.0.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect diff --git a/go.sum b/go.sum index 92d35cc..7cd6d88 100644 --- a/go.sum +++ b/go.sum @@ -72,6 +72,8 @@ github.com/goccy/go-json v0.10.3 h1:KZ5WoDbxAIgm2HNbYckL0se1fHD6rz5j4ywS6ebzDqA= github.com/goccy/go-json v0.10.3/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= github.com/golang-jwt/jwt/v4 v4.5.0 h1:7cYmW1XlMY7h7ii7UhUyChSgS5wUJEnm9uZVTGqOWzg= github.com/golang-jwt/jwt/v4 v4.5.0/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= +github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk= +github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE= diff --git a/internal/authorization/identity_provider.go b/internal/authorization/identity_provider.go new file mode 100644 index 0000000..8936d84 --- /dev/null +++ b/internal/authorization/identity_provider.go @@ -0,0 +1,96 @@ +package auth + +import ( + "context" + "encoding/json" + "errors" + "io" + "net/http" + + "donetick.com/core/config" + "golang.org/x/oauth2" +) + +type IdentityProviderUserInfo struct { + Identifier string + DisplayName string + Email string +} + +type IdentityProvider struct { + config *config.OAuth2Config + isEnabled bool +} + +func NewIdentityProvider(cfg *config.Config) *IdentityProvider { + if cfg.OAuth2Config.ClientID == "" || cfg.OAuth2Config.ClientSecret == "" { + return &IdentityProvider{isEnabled: false} + } + return &IdentityProvider{config: &cfg.OAuth2Config, isEnabled: true} +} + +func (i *IdentityProvider) ExchangeToken(ctx context.Context, code string) (string, error) { + if !i.isEnabled { + return "", errors.New("identity provider is not enabled") + } + + conf := &oauth2.Config{ + ClientID: i.config.ClientID, + ClientSecret: i.config.ClientSecret, + RedirectURL: i.config.RedirectURL, + Scopes: i.config.Scopes, + Endpoint: oauth2.Endpoint{ + AuthURL: i.config.AuthURL, + TokenURL: i.config.TokenURL, + }, + } + + token, err := conf.Exchange(ctx, code) + if err != nil { + return "", err + } + + accessToken, ok := token.Extra("access_token").(string) + if !ok { + return "", errors.New("access token not found") + } + + return accessToken, nil +} + +func (i *IdentityProvider) GetUserInfo(ctx context.Context, accessToken string) (*IdentityProviderUserInfo, error) { + req, err := http.NewRequest("GET", i.config.UserInfoURL, nil) + if err != nil { + return nil, err + } + + req.Header.Set("Authorization", "Bearer "+accessToken) + req.Header.Set("Accept", "application/json") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, err + } + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + var claims map[string]any + err = json.Unmarshal(body, &claims) + if err != nil { + return nil, errors.New("failed to unmarshal claims") + } + userInfo := IdentityProviderUserInfo{} + if val, ok := claims["sub"]; ok { + userInfo.Identifier = val.(string) + } + if val, ok := claims["name"]; ok { + userInfo.DisplayName = val.(string) + } + if val, ok := claims["email"]; ok { + userInfo.Email = val.(string) + } + return &userInfo, nil +} diff --git a/internal/thing/webhook.go b/internal/thing/api.go similarity index 87% rename from internal/thing/webhook.go rename to internal/thing/api.go index 03498cb..65c4f20 100644 --- a/internal/thing/webhook.go +++ b/internal/thing/api.go @@ -16,7 +16,7 @@ import ( "github.com/gin-gonic/gin" ) -type Webhook struct { +type API struct { choreRepo *chRepo.ChoreRepository circleRepo *cRepo.CircleRepository thingRepo *tRepo.ThingRepository @@ -24,9 +24,9 @@ type Webhook struct { tRepo *tRepo.ThingRepository } -func NewWebhook(cr *chRepo.ChoreRepository, circleRepo *cRepo.CircleRepository, - thingRepo *tRepo.ThingRepository, userRepo *uRepo.UserRepository, tRepo *tRepo.ThingRepository) *Webhook { - return &Webhook{ +func NewAPI(cr *chRepo.ChoreRepository, circleRepo *cRepo.CircleRepository, + thingRepo *tRepo.ThingRepository, userRepo *uRepo.UserRepository, tRepo *tRepo.ThingRepository) *API { + return &API{ choreRepo: cr, circleRepo: circleRepo, thingRepo: thingRepo, @@ -35,7 +35,7 @@ func NewWebhook(cr *chRepo.ChoreRepository, circleRepo *cRepo.CircleRepository, } } -func (h *Webhook) UpdateThingState(c *gin.Context) { +func (h *API) UpdateThingState(c *gin.Context) { thing, shouldReturn := validateUserAndThing(c, h) if shouldReturn { return @@ -60,7 +60,7 @@ func (h *Webhook) UpdateThingState(c *gin.Context) { c.JSON(200, gin.H{}) } -func (h *Webhook) ChangeThingState(c *gin.Context) { +func (h *API) ChangeThingState(c *gin.Context) { thing, shouldReturn := validateUserAndThing(c, h) if shouldReturn { return @@ -109,7 +109,7 @@ func (h *Webhook) ChangeThingState(c *gin.Context) { c.JSON(200, gin.H{"state": thing.State}) } -func WebhookEvaluateTriggerAndScheduleDueDate(h *Webhook, c *gin.Context, thing *tModel.Thing) bool { +func WebhookEvaluateTriggerAndScheduleDueDate(h *API, c *gin.Context, thing *tModel.Thing) bool { // handler should be interface to not duplicate both WebhookEvaluateTriggerAndScheduleDueDate and EvaluateTriggerAndScheduleDueDate // this is bad code written Saturday at 2:25 AM @@ -134,7 +134,7 @@ func WebhookEvaluateTriggerAndScheduleDueDate(h *Webhook, c *gin.Context, thing return false } -func validateUserAndThing(c *gin.Context, h *Webhook) (*tModel.Thing, bool) { +func validateUserAndThing(c *gin.Context, h *API) (*tModel.Thing, bool) { apiToken := c.GetHeader("secretkey") if apiToken == "" { c.JSON(401, gin.H{"error": "Unauthorized"}) @@ -162,7 +162,7 @@ func validateUserAndThing(c *gin.Context, h *Webhook) (*tModel.Thing, bool) { return thing, false } -func Webhooks(cfg *config.Config, w *Webhook, r *gin.Engine, auth *jwt.GinJWTMiddleware) { +func APIs(cfg *config.Config, w *API, r *gin.Engine, auth *jwt.GinJWTMiddleware) { thingsAPI := r.Group("eapi/v1/things") diff --git a/internal/user/handler.go b/internal/user/handler.go index 673659e..b7d33f6 100644 --- a/internal/user/handler.go +++ b/internal/user/handler.go @@ -31,16 +31,18 @@ type Handler struct { circleRepo *cRepo.CircleRepository jwtAuth *jwt.GinJWTMiddleware email *email.EmailSender + identityProvider *auth.IdentityProvider isDonetickDotCom bool IsUserCreationDisabled bool } -func NewHandler(ur *uRepo.UserRepository, cr *cRepo.CircleRepository, jwtAuth *jwt.GinJWTMiddleware, email *email.EmailSender, config *config.Config) *Handler { +func NewHandler(ur *uRepo.UserRepository, cr *cRepo.CircleRepository, jwtAuth *jwt.GinJWTMiddleware, email *email.EmailSender, idp *auth.IdentityProvider, config *config.Config) *Handler { return &Handler{ userRepo: ur, circleRepo: cr, jwtAuth: jwtAuth, email: email, + identityProvider: idp, isDonetickDotCom: config.IsDoneTickDotCom, IsUserCreationDisabled: config.IsUserCreationDisabled, } @@ -178,7 +180,8 @@ func (h *Handler) thirdPartyAuthCallback(c *gin.Context) { provider := c.Param("provider") logger.Infow("account.handler.thirdPartyAuthCallback", "provider", provider) - if provider == "google" { + switch provider { + case "google": c.Set("auth_provider", "3rdPartyAuth") type OAuthRequest struct { Token string `json:"token" binding:"required"` @@ -219,7 +222,7 @@ func (h *Handler) thirdPartyAuthCallback(c *gin.Context) { Image: userinfo.Picture, Password: encodedPassword, DisplayName: userinfo.GivenName, - Provider: 2, + Provider: uModel.AuthProviderGoogle, } createdUser, err := h.userRepo.CreateUser(c, acc) if err != nil { @@ -278,6 +281,106 @@ func (h *Handler) thirdPartyAuthCallback(c *gin.Context) { } c.JSON(http.StatusOK, gin.H{"token": tokenString, "expire": expire}) return + case "oauth2": + c.Set("auth_provider", "3rdPartyAuth") + // Read the ID token from the request bod + type Request struct { + Code string `json:"code"` + } + var req Request + if err := c.BindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request"}) + return + } + + token, err := h.identityProvider.ExchangeToken(c, req.Code) + if err != nil { + logger.Errorw("account.handler.thirdPartyAuthCallback (oauth2) failed to exchange token", "err", err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to exchange token"}) + return + } + claims, err := h.identityProvider.GetUserInfo(c, token) + if err != nil { + logger.Errorw("account.handler.thirdPartyAuthCallback (oauth2) failed to get claims", "err", err) + } + + acc, err := h.userRepo.FindByEmail(c, claims.Email) + if err != nil { + // Create user + password := auth.GenerateRandomPassword(12) + encodedPassword, err := auth.EncodePassword(password) + if err != nil { + logger.Errorw("account.handler.thirdPartyAuthCallback (oauth2) password encoding failed", "err", err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "Password encoding failed"}) + return + } + acc = &uModel.User{ + Username: claims.Email, + Email: claims.Email, + Password: encodedPassword, + DisplayName: claims.DisplayName, + Provider: uModel.AuthProviderOAuth2, + } + createdUser, err := h.userRepo.CreateUser(c, acc) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Unable to create user", + }) + return + + } + // Create Circle for the user: + userCircle, err := h.circleRepo.CreateCircle(c, &cModel.Circle{ + Name: claims.DisplayName + "'s circle", + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + InviteCode: utils.GenerateInviteCode(c), + }) + + if err != nil { + c.JSON(500, gin.H{ + "error": "Error creating circle", + }) + return + } + + if err := h.circleRepo.AddUserToCircle(c, &cModel.UserCircle{ + UserID: createdUser.ID, + CircleID: userCircle.ID, + Role: "admin", + IsActive: true, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + }); err != nil { + c.JSON(500, gin.H{ + "error": "Error adding user to circle", + }) + return + } + createdUser.CircleID = userCircle.ID + if err := h.userRepo.UpdateUser(c, createdUser); err != nil { + c.JSON(500, gin.H{ + "error": "Error updating user", + }) + return + } + } + // ... (JWT generation and response) + c.Set("user_account", acc) + h.jwtAuth.Authenticator(c) + tokenString, expire, err := h.jwtAuth.TokenGenerator(acc) + if err != nil { + logger.Errorw("Unable to Generate a Token") + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Unable to Generate a Token", + }) + return + } + c.JSON(http.StatusOK, gin.H{"token": tokenString, "expire": expire}) + return + default: + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid provider"}) + return } } diff --git a/internal/user/model/model.go b/internal/user/model/model.go index 6a4c434..e57069d 100644 --- a/internal/user/model/model.go +++ b/internal/user/model/model.go @@ -7,18 +7,18 @@ import ( ) type User struct { - ID int `json:"id" gorm:"primary_key"` // Unique identifier - DisplayName string `json:"displayName" gorm:"column:display_name"` // Display name - Username string `json:"username" gorm:"column:username;unique"` // Username (unique) - Email string `json:"email" gorm:"column:email;unique"` // Email (unique) - Provider int `json:"provider" gorm:"column:provider"` // Provider - Password string `json:"-" gorm:"column:password"` // Password - CircleID int `json:"circleID" gorm:"column:circle_id"` // Circle ID - ChatID int64 `json:"chatID" gorm:"column:chat_id"` // Telegram chat ID - Image string `json:"image" gorm:"column:image"` // Image - CreatedAt time.Time `json:"created_at" gorm:"column:created_at"` // Created at - UpdatedAt time.Time `json:"updated_at" gorm:"column:updated_at"` // Updated at - Disabled bool `json:"disabled" gorm:"column:disabled"` // Disabled + ID int `json:"id" gorm:"primary_key"` // Unique identifier + DisplayName string `json:"displayName" gorm:"column:display_name"` // Display name + Username string `json:"username" gorm:"column:username;unique"` // Username (unique) + Email string `json:"email" gorm:"column:email;unique"` // Email (unique) + Provider AuthProvider `json:"provider" gorm:"column:provider"` // Provider + Password string `json:"-" gorm:"column:password"` // Password + CircleID int `json:"circleID" gorm:"column:circle_id"` // Circle ID + ChatID int64 `json:"chatID" gorm:"column:chat_id"` // Telegram chat ID + Image string `json:"image" gorm:"column:image"` // Image + CreatedAt time.Time `json:"created_at" gorm:"column:created_at"` // Created at + UpdatedAt time.Time `json:"updated_at" gorm:"column:updated_at"` // Updated at + Disabled bool `json:"disabled" gorm:"column:disabled"` // Disabled // Email string `json:"email" gorm:"column:email"` // Email CustomerID *string `gorm:"column:customer_id;<-:false"` // read only column Subscription *string `json:"subscription" gorm:"column:subscription;<-:false"` // read only column @@ -48,3 +48,10 @@ type UserNotificationTarget struct { TargetID string `json:"target_id" gorm:"column:target_id"` // Target ID CreatedAt time.Time `json:"-" gorm:"column:created_at"` } +type AuthProvider int + +const ( + AuthProviderDonetick AuthProvider = iota + AuthProviderOAuth2 + AuthProviderGoogle +) diff --git a/main.go b/main.go index 3772ebd..9713efd 100644 --- a/main.go +++ b/main.go @@ -52,6 +52,7 @@ func main() { // fx.Provide(config.NewConfig), fx.Provide(auth.NewAuthMiddleware), + fx.Provide(auth.NewIdentityProvider), // fx.Provide(NewBot), fx.Provide(database.NewDatabase), @@ -89,7 +90,7 @@ func main() { fx.Provide(lRepo.NewLabelRepository), fx.Provide(label.NewHandler), - fx.Provide(thing.NewWebhook), + fx.Provide(thing.NewAPI), fx.Provide(thing.NewHandler), fx.Provide(chore.NewAPI), @@ -103,7 +104,7 @@ func main() { user.Routes, circle.Routes, thing.Routes, - thing.Webhooks, + thing.APIs, label.Routes, frontend.Routes,