diff --git a/apis/system/sysuser.go b/apis/system/sysuser.go index eec4f7b..2ab37a6 100644 --- a/apis/system/sysuser.go +++ b/apis/system/sysuser.go @@ -2,6 +2,7 @@ package system import ( "ferry/models/system" + "ferry/pkg/ldap" "ferry/pkg/logger" "ferry/tools" "ferry/tools/app" @@ -288,12 +289,22 @@ func SysUserUpdatePwd(c *gin.Context) { app.Error(c, -1, err, "") return } - sysuser := system.SysUser{} - sysuser.UserId = tools.GetUserId(c) - _, err = sysuser.SetPwd(pwd) - if err != nil { - app.Error(c, -1, err, "") - return + if pwd.PasswordType == 0 { + sysuser := system.SysUser{} + sysuser.UserId = tools.GetUserId(c) + _, err = sysuser.SetPwd(pwd) + if err != nil { + app.Error(c, -1, err, "") + return + } + } else if pwd.PasswordType == 1 { + // 修改ldap密码 + err = ldap.LdapUpdatePwd(tools.GetUserName(c), pwd.OldPassword, pwd.NewPassword) + if err != nil { + app.Error(c, -1, err, "") + return + } } + app.OK(c, "", "密码修改成功") } diff --git a/go.mod b/go.mod index ec5f7eb..9266531 100644 --- a/go.mod +++ b/go.mod @@ -33,6 +33,7 @@ require ( github.com/unrolled/secure v1.0.8 go.uber.org/zap v1.10.0 golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 + golang.org/x/text v0.3.3 gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc // indirect gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect gopkg.in/gomail.v2 v2.0.0-20160411212932-81ebce5c23df diff --git a/handler/auth.go b/handler/auth.go index 4644d66..2f28cd7 100644 --- a/handler/auth.go +++ b/handler/auth.go @@ -1,6 +1,7 @@ package handler import ( + "errors" "ferry/global/orm" "ferry/models/system" jwt "ferry/pkg/jwtauth" @@ -9,6 +10,7 @@ import ( "ferry/tools" "fmt" "net/http" + "time" "github.com/gin-gonic/gin" "github.com/mojocn/base64Captcha" @@ -60,15 +62,9 @@ func Authenticator(c *gin.Context) (interface{}, error) { loginLog system.LoginLog roleValue system.SysRole authUserCount int - l = ldap.Connection{} - userInfo system.SysUser - addUserInfo struct { - Username string `json:"username"` - RoleId int `json:"role_id"` - } + addUserInfo system.SysUser ) - loginType := c.DefaultQuery("login_type", "0") ua := user_agent.New(c.Request.UserAgent()) loginLog.Ipaddr = c.ClientIP() location := tools.GetLocation(c.ClientIP()) @@ -84,7 +80,6 @@ func Authenticator(c *gin.Context) (interface{}, error) { // 获取前端过来的数据 if err := c.ShouldBind(&loginVal); err != nil { - fmt.Println("********** " + err.Error() + " **********") loginLog.Status = "1" loginLog.Msg = "数据解析失败" loginLog.Username = loginVal.Username @@ -102,30 +97,33 @@ func Authenticator(c *gin.Context) (interface{}, error) { } // ldap 验证 - if loginType == "1" { + if loginVal.LoginType == 1 { // ldap登陆 - err = l.LdapLogin(loginVal.Username, loginVal.Password) + err = ldap.LdapLogin(loginVal.Username, loginVal.Password) if err != nil { - return nil, jwt.ErrInvalidVerificationode + return nil, err } // 2. 将ldap用户信息写入到用户数据表中 err = orm.Eloquent.Table("sys_user"). - Where("username = ?", userInfo.Username). + Where("username = ?", loginVal.Username). Count(&authUserCount).Error if err != nil { - return nil, jwt.ErrInvalidVerificationode + return nil, errors.New(fmt.Sprintf("查询用户失败,%v", err)) } if authUserCount == 0 { - addUserInfo.Username = userInfo.Username + addUserInfo.Username = loginVal.Username // 获取默认权限ID err = orm.Eloquent.Table("sys_role").Where("role_key = 'common'").Find(&roleValue).Error if err != nil { - return nil, jwt.ErrInvalidVerificationode + return nil, errors.New(fmt.Sprintf("查询角色失败,%v", err)) } addUserInfo.RoleId = roleValue.RoleId // 绑定通用角色 + addUserInfo.Status = "0" + addUserInfo.CreatedAt = time.Now() + addUserInfo.UpdatedAt = time.Now() err = orm.Eloquent.Table("sys_user").Create(&addUserInfo).Error if err != nil { - return nil, jwt.ErrInvalidVerificationode + return nil, errors.New(fmt.Sprintf("创建本地用户失败,%v", err)) } } } diff --git a/models/system/login.go b/models/system/login.go index 25d7d04..f799a01 100644 --- a/models/system/login.go +++ b/models/system/login.go @@ -10,10 +10,11 @@ import ( */ type Login struct { - Username string `form:"UserName" json:"username" binding:"required"` - Password string `form:"Password" json:"password" binding:"required"` - Code string `form:"Code" json:"code" binding:"required"` - UUID string `form:"UUID" json:"uuid" binding:"required"` + Username string `form:"UserName" json:"username" binding:"required"` + Password string `form:"Password" json:"password" binding:"required"` + Code string `form:"Code" json:"code" binding:"required"` + UUID string `form:"UUID" json:"uuid" binding:"required"` + LoginType int `form:"LoginType" json:"loginType"` } func (u *Login) GetUser() (user SysUser, role SysRole, e error) { @@ -22,10 +23,15 @@ func (u *Login) GetUser() (user SysUser, role SysRole, e error) { if e != nil { return } - _, e = tools.CompareHashAndPassword(user.Password, u.Password) - if e != nil { - return + + // 验证密码 + if u.LoginType == 0 { + _, e = tools.CompareHashAndPassword(user.Password, u.Password) + if e != nil { + return + } } + e = orm.Eloquent.Table("sys_role").Where("role_id = ? ", user.RoleId).First(&role).Error if e != nil { return diff --git a/models/system/sysuser.go b/models/system/sysuser.go index 1378366..614054f 100644 --- a/models/system/sysuser.go +++ b/models/system/sysuser.go @@ -73,8 +73,9 @@ func (SysUser) TableName() string { } type SysUserPwd struct { - OldPassword string `json:"oldPassword"` - NewPassword string `json:"newPassword"` + OldPassword string `json:"oldPassword" form:"oldPassword"` + NewPassword string `json:"newPassword" form:"newPassword"` + PasswordType int `json:"passwordType" form:"passwordType"` } type SysUserPage struct { diff --git a/pkg/ldap/connection.go b/pkg/ldap/connection.go index 1f276e0..419df53 100644 --- a/pkg/ldap/connection.go +++ b/pkg/ldap/connection.go @@ -2,6 +2,7 @@ package ldap import ( "crypto/tls" + "errors" "ferry/pkg/logger" "fmt" "time" @@ -15,34 +16,33 @@ import ( @Author : lanyulei */ -type Connection struct { - Conn *ldap.Conn -} +var conn *ldap.Conn // ldap连接 -func (c *Connection) ldapConnection() (err error) { +func ldapConnection() (err error) { var ldapConn = fmt.Sprintf("%v:%v", viper.GetString("settings.ldap.host"), viper.GetString("settings.ldap.port")) if viper.GetInt("settings.ldap.port") == 636 { - c.Conn, err = ldap.DialTLS( + conn, err = ldap.DialTLS( "tcp", ldapConn, &tls.Config{InsecureSkipVerify: true}, ) } else { - c.Conn, err = ldap.Dial( + conn, err = ldap.Dial( "tcp", ldapConn, ) } if err != nil { - logger.Errorf("无法连接到ldap服务器,%v", err) + err = errors.New(fmt.Sprintf("无法连接到ldap服务器,%v", err)) + logger.Error(err) return } //设置超时时间 - c.Conn.SetTimeout(5 * time.Second) + conn.SetTimeout(5 * time.Second) return } diff --git a/pkg/ldap/login.go b/pkg/ldap/login.go index d821e7d..522af6c 100644 --- a/pkg/ldap/login.go +++ b/pkg/ldap/login.go @@ -11,14 +11,14 @@ import ( @Author : lanyulei */ -func (c *Connection) LdapLogin(username string, password string) (err error) { - err = c.ldapConnection() +func LdapLogin(username string, password string) (err error) { + err = ldapConnection() if err != nil { return } - defer c.Conn.Close() + defer conn.Close() - err = c.Conn.Bind(fmt.Sprintf("cn=%v,%v", username, viper.GetString("settings.ldap.baseDn")), password) + err = conn.Bind(fmt.Sprintf("cn=%v,%v", username, viper.GetString("settings.ldap.baseDn")), password) if err != nil { logger.Error("用户或密码错误。", err) return diff --git a/pkg/ldap/updatePwd.go b/pkg/ldap/updatePwd.go new file mode 100644 index 0000000..ccdb63c --- /dev/null +++ b/pkg/ldap/updatePwd.go @@ -0,0 +1,46 @@ +package ldap + +import ( + "ferry/pkg/logger" + "fmt" + + "github.com/go-ldap/ldap/v3" + "golang.org/x/text/encoding/unicode" + + "github.com/spf13/viper" +) + +/* + @Author : lanyulei +*/ + +func LdapUpdatePwd(username string, oldPassword string, newPassword string) (err error) { + err = ldapConnection() + if err != nil { + return + } + defer conn.Close() + + var userDn = fmt.Sprintf("cn=%v,%v", username, viper.GetString("settings.ldap.baseDn")) + + err = conn.Bind(userDn, oldPassword) + if err != nil { + logger.Error("用户或密码错误。", err) + return + } + + sql2 := ldap.NewModifyRequest(userDn, nil) + + utf16 := unicode.UTF16(unicode.LittleEndian, unicode.IgnoreBOM) + pwdEncoded, _ := utf16.NewEncoder().String(newPassword) + + sql2.Replace("unicodePwd", []string{pwdEncoded}) + sql2.Replace("userAccountControl", []string{"512"}) + + if err = conn.Modify(sql2); err != nil { + logger.Error("密码修改失败,%v", err.Error()) + return + } + + return +}