From 4c1b402137fa9f6a38ff7ae154669c734b1f1e2a Mon Sep 17 00:00:00 2001 From: Mo Tarbin Date: Thu, 6 Feb 2025 21:53:34 -0500 Subject: [PATCH] Update config to support OIDC and oauth2 --- config/config.go | 19 +++++--- config/local.yaml | 4 +- config/selfhosted.yaml | 1 + internal/authorization/identity_provider.go | 3 +- internal/resource/handler.go | 54 +++++++++++++++++++++ internal/user/handler.go | 22 +++------ internal/user/model/model.go | 28 +++++------ main.go | 3 ++ 8 files changed, 92 insertions(+), 42 deletions(-) create mode 100644 internal/resource/handler.go diff --git a/config/config.go b/config/config.go index c91664c..ecadfa3 100644 --- a/config/config.go +++ b/config/config.go @@ -3,6 +3,7 @@ package config import ( "fmt" "os" + "strings" "time" "github.com/spf13/viper" @@ -86,13 +87,14 @@ type EmailConfig struct { } type OAuth2Config struct { - ClientID string `mapstructure:"client_id" yaml:"client_id"` - ClientSecret string `mapstructure:"client_secret" yaml:"client_secret"` - RedirectURL string `mapstructure:"redirect_url" yaml:"redirect_url"` - Scopes []string - AuthURL string `mapstructure:"auth_url" yaml:"auth_url"` - TokenURL string `mapstructure:"token_url" yaml:"token_url"` - UserInfoURL string `mapstructure:"user_info_url" yaml:"user_info_url"` + ClientID string `mapstructure:"client_id" yaml:"client_id"` + ClientSecret string `mapstructure:"client_secret" yaml:"client_secret"` + RedirectURL string `mapstructure:"redirect_url" yaml:"redirect_url"` + Scopes []string `mapstructure:"scopes" yaml:"scopes"` + AuthURL string `mapstructure:"auth_url" yaml:"auth_url"` + TokenURL string `mapstructure:"token_url" yaml:"token_url"` + UserInfoURL string `mapstructure:"user_info_url" yaml:"user_info_url"` + Name string `mapstructure:"name" yaml:"name"` } func NewConfig() *Config { @@ -138,6 +140,7 @@ func LoadConfig() *Config { // get logger and log the current environment: fmt.Printf("--ConfigLoad config for environment: %s ", os.Getenv("DT_ENV")) viper.SetEnvPrefix("DT") + viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_")) viper.AutomaticEnv() viper.AddConfigPath("./config") @@ -155,11 +158,11 @@ func LoadConfig() *Config { panic(err) } fmt.Printf("--ConfigLoad name : %s ", config.Name) + panic(config.OAuth2Config.ClientID) // bind all the environment variables to the config: configEnvironmentOverrides(&config) - panic(config.OAuth2Config.ClientID) return &config // return LocalConfig() diff --git a/config/local.yaml b/config/local.yaml index 7018f24..f9c888f 100644 --- a/config/local.yaml +++ b/config/local.yaml @@ -14,8 +14,8 @@ jwt: max_refresh: 168h server: port: 2021 - read_timeout: 2s - write_timeout: 1s + read_timeout: 10s + write_timeout: 10s rate_period: 60s rate_limit: 300 cors_allow_origins: diff --git a/config/selfhosted.yaml b/config/selfhosted.yaml index e8641d1..26e2eb5 100644 --- a/config/selfhosted.yaml +++ b/config/selfhosted.yaml @@ -42,3 +42,4 @@ oauth2: token_url: user_info_url: redirect_url: + name: \ No newline at end of file diff --git a/internal/authorization/identity_provider.go b/internal/authorization/identity_provider.go index 8936d84..6714bfe 100644 --- a/internal/authorization/identity_provider.go +++ b/internal/authorization/identity_provider.go @@ -44,13 +44,12 @@ func (i *IdentityProvider) ExchangeToken(ctx context.Context, code string) (stri TokenURL: i.config.TokenURL, }, } - token, err := conf.Exchange(ctx, code) if err != nil { return "", err } - accessToken, ok := token.Extra("access_token").(string) + accessToken, ok := token.AccessToken, token.Valid() if !ok { return "", errors.New("access token not found") } diff --git a/internal/resource/handler.go b/internal/resource/handler.go new file mode 100644 index 0000000..3038044 --- /dev/null +++ b/internal/resource/handler.go @@ -0,0 +1,54 @@ +package resource + +import ( + "donetick.com/core/config" + jwt "github.com/appleboy/gin-jwt/v2" + "github.com/gin-gonic/gin" +) + +type Resource struct { + Idp identityProvider `json:"identity_provider" binding:"omitempty"` +} +type identityProvider struct { + Auth_url string `json:"auth_url" binding:"omitempty"` + Client_ID string `json:"client_id" binding:"omitempty"` + Name string `json:"name" binding:"omitempty"` +} + +type Handler struct { + config config.Config +} + +func NewHandler(cfg *config.Config) *Handler { + return &Handler{ + config: *cfg, + } +} + +func (h *Handler) getResource(c *gin.Context) { + c.JSON(200, &Resource{ + Idp: identityProvider{ + Auth_url: h.config.OAuth2Config.AuthURL, + Client_ID: h.config.OAuth2Config.ClientID, + Name: h.config.OAuth2Config.Name, + }, + }, + ) +} + +func Routes(r *gin.Engine, h *Handler, auth *jwt.GinJWTMiddleware, cfg *config.Config) { + resourceRoutes := r.Group("api/v1/resource") + + if cfg.IsDoneTickDotCom { + // skip resource endpoint for donetick.com + resourceRoutes.GET("", func(c *gin.Context) { + c.JSON(200, gin.H{}) + }) + return + } + + { + resourceRoutes.GET("", h.getResource) + } + +} diff --git a/internal/user/handler.go b/internal/user/handler.go index b7d33f6..ed87890 100644 --- a/internal/user/handler.go +++ b/internal/user/handler.go @@ -294,14 +294,16 @@ func (h *Handler) thirdPartyAuthCallback(c *gin.Context) { } token, err := h.identityProvider.ExchangeToken(c, req.Code) + if err != nil { - logger.Errorw("account.handler.thirdPartyAuthCallback (oauth2) failed to exchange token", "err", err) + logger.Error("account.handler.thirdPartyAuthCallback (oauth2) failed to exchange token", "err", err) c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to exchange token"}) return } + claims, err := h.identityProvider.GetUserInfo(c, token) if err != nil { - logger.Errorw("account.handler.thirdPartyAuthCallback (oauth2) failed to get claims", "err", err) + logger.Error("account.handler.thirdPartyAuthCallback (oauth2) failed to get claims", "err", err) } acc, err := h.userRepo.FindByEmail(c, claims.Email) @@ -310,7 +312,7 @@ func (h *Handler) thirdPartyAuthCallback(c *gin.Context) { password := auth.GenerateRandomPassword(12) encodedPassword, err := auth.EncodePassword(password) if err != nil { - logger.Errorw("account.handler.thirdPartyAuthCallback (oauth2) password encoding failed", "err", err) + logger.Error("account.handler.thirdPartyAuthCallback (oauth2) password encoding failed", "err", err) c.JSON(http.StatusInternalServerError, gin.H{"error": "Password encoding failed"}) return } @@ -370,7 +372,7 @@ func (h *Handler) thirdPartyAuthCallback(c *gin.Context) { h.jwtAuth.Authenticator(c) tokenString, expire, err := h.jwtAuth.TokenGenerator(acc) if err != nil { - logger.Errorw("Unable to Generate a Token") + logger.Error("Unable to Generate a Token") c.JSON(http.StatusInternalServerError, gin.H{ "error": "Unable to Generate a Token", }) @@ -378,9 +380,6 @@ func (h *Handler) thirdPartyAuthCallback(c *gin.Context) { } c.JSON(http.StatusOK, gin.H{"token": tokenString, "expire": expire}) return - default: - c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid provider"}) - return } } @@ -713,13 +712,4 @@ func Routes(router *gin.Engine, h *Handler, auth *jwt.GinJWTMiddleware, limiter authRoutes.POST("reset", h.resetPassword) authRoutes.POST("password", h.updateUserPassword) } - pingRoutes := router.Group("api/v1/ping") - pingRoutes.Use(utils.RateLimitMiddleware(limiter)) - { - pingRoutes.GET("/", func(c *gin.Context) { - c.JSON(200, gin.H{ - "message": "pong", - }) - }) - } } diff --git a/internal/user/model/model.go b/internal/user/model/model.go index e57069d..5d70e98 100644 --- a/internal/user/model/model.go +++ b/internal/user/model/model.go @@ -7,18 +7,18 @@ import ( ) type User struct { - ID int `json:"id" gorm:"primary_key"` // Unique identifier - DisplayName string `json:"displayName" gorm:"column:display_name"` // Display name - Username string `json:"username" gorm:"column:username;unique"` // Username (unique) - Email string `json:"email" gorm:"column:email;unique"` // Email (unique) - Provider AuthProvider `json:"provider" gorm:"column:provider"` // Provider - Password string `json:"-" gorm:"column:password"` // Password - CircleID int `json:"circleID" gorm:"column:circle_id"` // Circle ID - ChatID int64 `json:"chatID" gorm:"column:chat_id"` // Telegram chat ID - Image string `json:"image" gorm:"column:image"` // Image - CreatedAt time.Time `json:"created_at" gorm:"column:created_at"` // Created at - UpdatedAt time.Time `json:"updated_at" gorm:"column:updated_at"` // Updated at - Disabled bool `json:"disabled" gorm:"column:disabled"` // Disabled + ID int `json:"id" gorm:"primary_key"` // Unique identifier + DisplayName string `json:"displayName" gorm:"column:display_name"` // Display name + Username string `json:"username" gorm:"column:username;unique"` // Username (unique) + Email string `json:"email" gorm:"column:email;unique"` // Email (unique) + Provider AuthProviderType `json:"provider" gorm:"column:provider"` // Provider + Password string `json:"-" gorm:"column:password"` // Password + CircleID int `json:"circleID" gorm:"column:circle_id"` // Circle ID + ChatID int64 `json:"chatID" gorm:"column:chat_id"` // Telegram chat ID + Image string `json:"image" gorm:"column:image"` // Image + CreatedAt time.Time `json:"created_at" gorm:"column:created_at"` // Created at + UpdatedAt time.Time `json:"updated_at" gorm:"column:updated_at"` // Updated at + Disabled bool `json:"disabled" gorm:"column:disabled"` // Disabled // Email string `json:"email" gorm:"column:email"` // Email CustomerID *string `gorm:"column:customer_id;<-:false"` // read only column Subscription *string `json:"subscription" gorm:"column:subscription;<-:false"` // read only column @@ -48,10 +48,10 @@ type UserNotificationTarget struct { TargetID string `json:"target_id" gorm:"column:target_id"` // Target ID CreatedAt time.Time `json:"-" gorm:"column:created_at"` } -type AuthProvider int +type AuthProviderType int const ( - AuthProviderDonetick AuthProvider = iota + AuthProviderDonetick AuthProviderType = iota AuthProviderOAuth2 AuthProviderGoogle ) diff --git a/main.go b/main.go index 9713efd..a168f39 100644 --- a/main.go +++ b/main.go @@ -24,6 +24,7 @@ import ( "donetick.com/core/internal/email" label "donetick.com/core/internal/label" lRepo "donetick.com/core/internal/label/repo" + "donetick.com/core/internal/resource" notifier "donetick.com/core/internal/notifier" nRepo "donetick.com/core/internal/notifier/repo" @@ -53,6 +54,7 @@ func main() { // fx.Provide(config.NewConfig), fx.Provide(auth.NewAuthMiddleware), fx.Provide(auth.NewIdentityProvider), + fx.Provide(resource.NewHandler), // fx.Provide(NewBot), fx.Provide(database.NewDatabase), @@ -107,6 +109,7 @@ func main() { thing.APIs, label.Routes, frontend.Routes, + resource.Routes, func(r *gin.Engine) {}, ),