oidc_test.go 57 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774
  1. // Copyright (C) 2019-2023 Nicola Murino
  2. //
  3. // This program is free software: you can redistribute it and/or modify
  4. // it under the terms of the GNU Affero General Public License as published
  5. // by the Free Software Foundation, version 3.
  6. //
  7. // This program is distributed in the hope that it will be useful,
  8. // but WITHOUT ANY WARRANTY; without even the implied warranty of
  9. // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  10. // GNU Affero General Public License for more details.
  11. //
  12. // You should have received a copy of the GNU Affero General Public License
  13. // along with this program. If not, see <https://www.gnu.org/licenses/>.
  14. package httpd
  15. import (
  16. "bytes"
  17. "context"
  18. "encoding/json"
  19. "fmt"
  20. "io/fs"
  21. "net/http"
  22. "net/http/httptest"
  23. "net/url"
  24. "os"
  25. "path/filepath"
  26. "reflect"
  27. "runtime"
  28. "testing"
  29. "time"
  30. "unsafe"
  31. "github.com/coreos/go-oidc/v3/oidc"
  32. "github.com/go-chi/jwtauth/v5"
  33. "github.com/lestrrat-go/jwx/v2/jwa"
  34. "github.com/rs/xid"
  35. "github.com/sftpgo/sdk"
  36. "github.com/stretchr/testify/assert"
  37. "github.com/stretchr/testify/require"
  38. "golang.org/x/oauth2"
  39. "github.com/drakkan/sftpgo/v2/internal/common"
  40. "github.com/drakkan/sftpgo/v2/internal/dataprovider"
  41. "github.com/drakkan/sftpgo/v2/internal/kms"
  42. "github.com/drakkan/sftpgo/v2/internal/util"
  43. "github.com/drakkan/sftpgo/v2/internal/vfs"
  44. )
  45. const (
  46. oidcMockAddr = "127.0.0.1:11111"
  47. )
  48. type mockTokenSource struct {
  49. token *oauth2.Token
  50. err error
  51. }
  52. func (t *mockTokenSource) Token() (*oauth2.Token, error) {
  53. return t.token, t.err
  54. }
  55. type mockOAuth2Config struct {
  56. tokenSource *mockTokenSource
  57. authCodeURL string
  58. token *oauth2.Token
  59. err error
  60. }
  61. func (c *mockOAuth2Config) AuthCodeURL(_ string, _ ...oauth2.AuthCodeOption) string {
  62. return c.authCodeURL
  63. }
  64. func (c *mockOAuth2Config) Exchange(_ context.Context, _ string, _ ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
  65. return c.token, c.err
  66. }
  67. func (c *mockOAuth2Config) TokenSource(_ context.Context, _ *oauth2.Token) oauth2.TokenSource {
  68. return c.tokenSource
  69. }
  70. type mockOIDCVerifier struct {
  71. token *oidc.IDToken
  72. err error
  73. }
  74. func (v *mockOIDCVerifier) Verify(_ context.Context, _ string) (*oidc.IDToken, error) {
  75. return v.token, v.err
  76. }
  77. // hack because the field is unexported
  78. func setIDTokenClaims(idToken *oidc.IDToken, claims []byte) {
  79. pointerVal := reflect.ValueOf(idToken)
  80. val := reflect.Indirect(pointerVal)
  81. member := val.FieldByName("claims")
  82. ptr := unsafe.Pointer(member.UnsafeAddr())
  83. realPtr := (*[]byte)(ptr)
  84. *realPtr = claims
  85. }
  86. func TestOIDCInitialization(t *testing.T) {
  87. config := OIDC{}
  88. err := config.initialize()
  89. assert.NoError(t, err)
  90. secret := "jRsmE0SWnuZjP7djBqNq0mrf8QN77j2c"
  91. config = OIDC{
  92. ClientID: "sftpgo-client",
  93. ClientSecret: util.GenerateUniqueID(),
  94. ConfigURL: fmt.Sprintf("http://%v/", oidcMockAddr),
  95. RedirectBaseURL: "http://127.0.0.1:8081/",
  96. UsernameField: "preferred_username",
  97. RoleField: "sftpgo_role",
  98. }
  99. err = config.initialize()
  100. if assert.Error(t, err) {
  101. assert.Contains(t, err.Error(), "oidc: required scope \"openid\" is not set")
  102. }
  103. config.Scopes = []string{oidc.ScopeOpenID}
  104. config.ClientSecretFile = "missing file"
  105. err = config.initialize()
  106. assert.ErrorIs(t, err, fs.ErrNotExist)
  107. secretFile := filepath.Join(os.TempDir(), util.GenerateUniqueID())
  108. defer os.Remove(secretFile)
  109. err = os.WriteFile(secretFile, []byte(secret), 0600)
  110. assert.NoError(t, err)
  111. config.ClientSecretFile = secretFile
  112. err = config.initialize()
  113. if assert.Error(t, err) {
  114. assert.Contains(t, err.Error(), "oidc: unable to initialize provider")
  115. }
  116. assert.Equal(t, secret, config.ClientSecret)
  117. config.ConfigURL = fmt.Sprintf("http://%v/auth/realms/sftpgo", oidcMockAddr)
  118. err = config.initialize()
  119. assert.NoError(t, err)
  120. assert.Equal(t, "http://127.0.0.1:8081"+webOIDCRedirectPath, config.getRedirectURL())
  121. }
  122. func TestOIDCLoginLogout(t *testing.T) {
  123. oidcMgr, ok := oidcMgr.(*memoryOIDCManager)
  124. require.True(t, ok)
  125. server := getTestOIDCServer()
  126. err := server.binding.OIDC.initialize()
  127. assert.NoError(t, err)
  128. server.initializeRouter()
  129. rr := httptest.NewRecorder()
  130. r, err := http.NewRequest(http.MethodGet, webOIDCRedirectPath, nil)
  131. assert.NoError(t, err)
  132. server.router.ServeHTTP(rr, r)
  133. assert.Equal(t, http.StatusBadRequest, rr.Code)
  134. assert.Contains(t, rr.Body.String(), util.I18nInvalidAuth)
  135. expiredAuthReq := oidcPendingAuth{
  136. State: xid.New().String(),
  137. Nonce: xid.New().String(),
  138. Audience: tokenAudienceWebClient,
  139. IssuedAt: util.GetTimeAsMsSinceEpoch(time.Now().Add(-10 * time.Minute)),
  140. }
  141. oidcMgr.addPendingAuth(expiredAuthReq)
  142. rr = httptest.NewRecorder()
  143. r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+expiredAuthReq.State, nil)
  144. assert.NoError(t, err)
  145. server.router.ServeHTTP(rr, r)
  146. assert.Equal(t, http.StatusBadRequest, rr.Code)
  147. assert.Contains(t, rr.Body.String(), util.I18nInvalidAuth)
  148. oidcMgr.removePendingAuth(expiredAuthReq.State)
  149. server.binding.OIDC.oauth2Config = &mockOAuth2Config{
  150. tokenSource: &mockTokenSource{},
  151. authCodeURL: webOIDCRedirectPath,
  152. err: common.ErrGenericFailure,
  153. }
  154. server.binding.OIDC.verifier = &mockOIDCVerifier{
  155. err: common.ErrGenericFailure,
  156. }
  157. rr = httptest.NewRecorder()
  158. r, err = http.NewRequest(http.MethodGet, webAdminOIDCLoginPath, nil)
  159. assert.NoError(t, err)
  160. server.router.ServeHTTP(rr, r)
  161. assert.Equal(t, http.StatusFound, rr.Code)
  162. assert.Equal(t, webOIDCRedirectPath, rr.Header().Get("Location"))
  163. require.Len(t, oidcMgr.pendingAuths, 1)
  164. var state string
  165. for k := range oidcMgr.pendingAuths {
  166. state = k
  167. }
  168. rr = httptest.NewRecorder()
  169. r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+state, nil)
  170. assert.NoError(t, err)
  171. server.router.ServeHTTP(rr, r)
  172. assert.Equal(t, http.StatusFound, rr.Code)
  173. assert.Equal(t, webAdminLoginPath, rr.Header().Get("Location"))
  174. require.Len(t, oidcMgr.pendingAuths, 0)
  175. rr = httptest.NewRecorder()
  176. r, err = http.NewRequest(http.MethodGet, webAdminLoginPath, nil)
  177. assert.NoError(t, err)
  178. server.router.ServeHTTP(rr, r)
  179. assert.Equal(t, http.StatusOK, rr.Code)
  180. // now the same for the web client
  181. rr = httptest.NewRecorder()
  182. r, err = http.NewRequest(http.MethodGet, webClientOIDCLoginPath, nil)
  183. assert.NoError(t, err)
  184. server.router.ServeHTTP(rr, r)
  185. assert.Equal(t, http.StatusFound, rr.Code)
  186. assert.Equal(t, webOIDCRedirectPath, rr.Header().Get("Location"))
  187. require.Len(t, oidcMgr.pendingAuths, 1)
  188. for k := range oidcMgr.pendingAuths {
  189. state = k
  190. }
  191. rr = httptest.NewRecorder()
  192. r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+state, nil)
  193. assert.NoError(t, err)
  194. server.router.ServeHTTP(rr, r)
  195. assert.Equal(t, http.StatusFound, rr.Code)
  196. assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
  197. require.Len(t, oidcMgr.pendingAuths, 0)
  198. rr = httptest.NewRecorder()
  199. r, err = http.NewRequest(http.MethodGet, webClientLoginPath, nil)
  200. assert.NoError(t, err)
  201. server.router.ServeHTTP(rr, r)
  202. assert.Equal(t, http.StatusOK, rr.Code)
  203. // now return an OAuth2 token without the id_token
  204. server.binding.OIDC.oauth2Config = &mockOAuth2Config{
  205. tokenSource: &mockTokenSource{},
  206. authCodeURL: webOIDCRedirectPath,
  207. token: &oauth2.Token{
  208. AccessToken: "123",
  209. Expiry: time.Now().Add(5 * time.Minute),
  210. },
  211. err: nil,
  212. }
  213. authReq := newOIDCPendingAuth(tokenAudienceWebClient)
  214. oidcMgr.addPendingAuth(authReq)
  215. rr = httptest.NewRecorder()
  216. r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
  217. assert.NoError(t, err)
  218. server.router.ServeHTTP(rr, r)
  219. assert.Equal(t, http.StatusFound, rr.Code)
  220. assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
  221. require.Len(t, oidcMgr.pendingAuths, 0)
  222. // now fail to verify the id token
  223. token := &oauth2.Token{
  224. AccessToken: "123",
  225. Expiry: time.Now().Add(5 * time.Minute),
  226. }
  227. token = token.WithExtra(map[string]any{
  228. "id_token": "id_token_val",
  229. })
  230. server.binding.OIDC.oauth2Config = &mockOAuth2Config{
  231. tokenSource: &mockTokenSource{},
  232. authCodeURL: webOIDCRedirectPath,
  233. token: token,
  234. err: nil,
  235. }
  236. authReq = newOIDCPendingAuth(tokenAudienceWebClient)
  237. oidcMgr.addPendingAuth(authReq)
  238. rr = httptest.NewRecorder()
  239. r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
  240. assert.NoError(t, err)
  241. server.router.ServeHTTP(rr, r)
  242. assert.Equal(t, http.StatusFound, rr.Code)
  243. assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
  244. require.Len(t, oidcMgr.pendingAuths, 0)
  245. // id token nonce does not match
  246. server.binding.OIDC.verifier = &mockOIDCVerifier{
  247. err: nil,
  248. token: &oidc.IDToken{},
  249. }
  250. authReq = newOIDCPendingAuth(tokenAudienceWebClient)
  251. oidcMgr.addPendingAuth(authReq)
  252. rr = httptest.NewRecorder()
  253. r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
  254. assert.NoError(t, err)
  255. server.router.ServeHTTP(rr, r)
  256. assert.Equal(t, http.StatusFound, rr.Code)
  257. assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
  258. require.Len(t, oidcMgr.pendingAuths, 0)
  259. // null id token claims
  260. authReq = newOIDCPendingAuth(tokenAudienceWebClient)
  261. oidcMgr.addPendingAuth(authReq)
  262. server.binding.OIDC.verifier = &mockOIDCVerifier{
  263. err: nil,
  264. token: &oidc.IDToken{
  265. Nonce: authReq.Nonce,
  266. },
  267. }
  268. rr = httptest.NewRecorder()
  269. r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
  270. assert.NoError(t, err)
  271. server.router.ServeHTTP(rr, r)
  272. assert.Equal(t, http.StatusFound, rr.Code)
  273. assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
  274. require.Len(t, oidcMgr.pendingAuths, 0)
  275. // invalid id token claims: no username
  276. authReq = newOIDCPendingAuth(tokenAudienceWebClient)
  277. oidcMgr.addPendingAuth(authReq)
  278. idToken := &oidc.IDToken{
  279. Nonce: authReq.Nonce,
  280. Expiry: time.Now().Add(5 * time.Minute),
  281. }
  282. setIDTokenClaims(idToken, []byte(`{"aud": "my_client_id"}`))
  283. server.binding.OIDC.verifier = &mockOIDCVerifier{
  284. err: nil,
  285. token: idToken,
  286. }
  287. rr = httptest.NewRecorder()
  288. r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
  289. assert.NoError(t, err)
  290. server.router.ServeHTTP(rr, r)
  291. assert.Equal(t, http.StatusFound, rr.Code)
  292. assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
  293. require.Len(t, oidcMgr.pendingAuths, 0)
  294. // invalid id token clamims: username not a string
  295. authReq = newOIDCPendingAuth(tokenAudienceWebClient)
  296. oidcMgr.addPendingAuth(authReq)
  297. idToken = &oidc.IDToken{
  298. Nonce: authReq.Nonce,
  299. Expiry: time.Now().Add(5 * time.Minute),
  300. }
  301. setIDTokenClaims(idToken, []byte(`{"aud": "my_client_id","preferred_username": 1}`))
  302. server.binding.OIDC.verifier = &mockOIDCVerifier{
  303. err: nil,
  304. token: idToken,
  305. }
  306. rr = httptest.NewRecorder()
  307. r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
  308. assert.NoError(t, err)
  309. server.router.ServeHTTP(rr, r)
  310. assert.Equal(t, http.StatusFound, rr.Code)
  311. assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
  312. require.Len(t, oidcMgr.pendingAuths, 0)
  313. // invalid audience
  314. authReq = newOIDCPendingAuth(tokenAudienceWebClient)
  315. oidcMgr.addPendingAuth(authReq)
  316. idToken = &oidc.IDToken{
  317. Nonce: authReq.Nonce,
  318. Expiry: time.Now().Add(5 * time.Minute),
  319. }
  320. setIDTokenClaims(idToken, []byte(`{"preferred_username":"test","sftpgo_role":"admin"}`))
  321. server.binding.OIDC.verifier = &mockOIDCVerifier{
  322. err: nil,
  323. token: idToken,
  324. }
  325. rr = httptest.NewRecorder()
  326. r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
  327. assert.NoError(t, err)
  328. server.router.ServeHTTP(rr, r)
  329. assert.Equal(t, http.StatusFound, rr.Code)
  330. assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
  331. require.Len(t, oidcMgr.pendingAuths, 0)
  332. // invalid audience
  333. authReq = newOIDCPendingAuth(tokenAudienceWebAdmin)
  334. oidcMgr.addPendingAuth(authReq)
  335. idToken = &oidc.IDToken{
  336. Nonce: authReq.Nonce,
  337. Expiry: time.Now().Add(5 * time.Minute),
  338. }
  339. setIDTokenClaims(idToken, []byte(`{"preferred_username":"test"}`))
  340. server.binding.OIDC.verifier = &mockOIDCVerifier{
  341. err: nil,
  342. token: idToken,
  343. }
  344. rr = httptest.NewRecorder()
  345. r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
  346. assert.NoError(t, err)
  347. server.router.ServeHTTP(rr, r)
  348. assert.Equal(t, http.StatusFound, rr.Code)
  349. assert.Equal(t, webAdminLoginPath, rr.Header().Get("Location"))
  350. require.Len(t, oidcMgr.pendingAuths, 0)
  351. // mapped user not found
  352. authReq = newOIDCPendingAuth(tokenAudienceWebAdmin)
  353. oidcMgr.addPendingAuth(authReq)
  354. idToken = &oidc.IDToken{
  355. Nonce: authReq.Nonce,
  356. Expiry: time.Now().Add(5 * time.Minute),
  357. }
  358. setIDTokenClaims(idToken, []byte(`{"preferred_username":"test","sftpgo_role":"admin"}`))
  359. server.binding.OIDC.verifier = &mockOIDCVerifier{
  360. err: nil,
  361. token: idToken,
  362. }
  363. rr = httptest.NewRecorder()
  364. r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
  365. assert.NoError(t, err)
  366. server.router.ServeHTTP(rr, r)
  367. assert.Equal(t, http.StatusFound, rr.Code)
  368. assert.Equal(t, webAdminLoginPath, rr.Header().Get("Location"))
  369. require.Len(t, oidcMgr.pendingAuths, 0)
  370. // admin login ok
  371. authReq = newOIDCPendingAuth(tokenAudienceWebAdmin)
  372. oidcMgr.addPendingAuth(authReq)
  373. idToken = &oidc.IDToken{
  374. Nonce: authReq.Nonce,
  375. Expiry: time.Now().Add(5 * time.Minute),
  376. }
  377. setIDTokenClaims(idToken, []byte(`{"preferred_username":"admin","sftpgo_role":"admin","sid":"sid123"}`))
  378. server.binding.OIDC.verifier = &mockOIDCVerifier{
  379. err: nil,
  380. token: idToken,
  381. }
  382. rr = httptest.NewRecorder()
  383. r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
  384. assert.NoError(t, err)
  385. server.router.ServeHTTP(rr, r)
  386. assert.Equal(t, http.StatusFound, rr.Code)
  387. assert.Equal(t, webUsersPath, rr.Header().Get("Location"))
  388. require.Len(t, oidcMgr.pendingAuths, 0)
  389. require.Len(t, oidcMgr.tokens, 1)
  390. // admin profile is not available
  391. var tokenCookie string
  392. for k := range oidcMgr.tokens {
  393. tokenCookie = k
  394. }
  395. oidcToken, err := oidcMgr.getToken(tokenCookie)
  396. assert.NoError(t, err)
  397. assert.Equal(t, "sid123", oidcToken.SessionID)
  398. assert.True(t, oidcToken.isAdmin())
  399. assert.False(t, oidcToken.isExpired())
  400. rr = httptest.NewRecorder()
  401. r, err = http.NewRequest(http.MethodGet, webAdminProfilePath, nil)
  402. assert.NoError(t, err)
  403. r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, tokenCookie))
  404. server.router.ServeHTTP(rr, r)
  405. assert.Equal(t, http.StatusForbidden, rr.Code)
  406. // the admin can access the allowed pages
  407. rr = httptest.NewRecorder()
  408. r, err = http.NewRequest(http.MethodGet, webUsersPath, nil)
  409. assert.NoError(t, err)
  410. r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, tokenCookie))
  411. server.router.ServeHTTP(rr, r)
  412. assert.Equal(t, http.StatusOK, rr.Code)
  413. // try with an invalid cookie
  414. rr = httptest.NewRecorder()
  415. r, err = http.NewRequest(http.MethodGet, webUsersPath, nil)
  416. assert.NoError(t, err)
  417. r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, xid.New().String()))
  418. server.router.ServeHTTP(rr, r)
  419. assert.Equal(t, http.StatusFound, rr.Code)
  420. assert.Equal(t, webAdminLoginPath, rr.Header().Get("Location"))
  421. // Web Client is not available with an admin token
  422. rr = httptest.NewRecorder()
  423. r, err = http.NewRequest(http.MethodGet, webClientFilesPath, nil)
  424. assert.NoError(t, err)
  425. r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, tokenCookie))
  426. server.router.ServeHTTP(rr, r)
  427. assert.Equal(t, http.StatusFound, rr.Code)
  428. assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
  429. // logout the admin user
  430. rr = httptest.NewRecorder()
  431. r, err = http.NewRequest(http.MethodGet, webLogoutPath, nil)
  432. assert.NoError(t, err)
  433. r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, tokenCookie))
  434. server.router.ServeHTTP(rr, r)
  435. assert.Equal(t, http.StatusFound, rr.Code)
  436. assert.Equal(t, webAdminLoginPath, rr.Header().Get("Location"))
  437. require.Len(t, oidcMgr.pendingAuths, 0)
  438. require.Len(t, oidcMgr.tokens, 0)
  439. // now login and logout a user
  440. username := "test_oidc_user"
  441. user := dataprovider.User{
  442. BaseUser: sdk.BaseUser{
  443. Username: username,
  444. Password: "pwd",
  445. HomeDir: filepath.Join(os.TempDir(), username),
  446. Status: 1,
  447. Permissions: map[string][]string{
  448. "/": {dataprovider.PermAny},
  449. },
  450. },
  451. Filters: dataprovider.UserFilters{
  452. BaseUserFilters: sdk.BaseUserFilters{
  453. WebClient: []string{sdk.WebClientSharesDisabled},
  454. },
  455. },
  456. }
  457. err = dataprovider.AddUser(&user, "", "", "")
  458. assert.NoError(t, err)
  459. authReq = newOIDCPendingAuth(tokenAudienceWebClient)
  460. oidcMgr.addPendingAuth(authReq)
  461. idToken = &oidc.IDToken{
  462. Nonce: authReq.Nonce,
  463. Expiry: time.Now().Add(5 * time.Minute),
  464. }
  465. setIDTokenClaims(idToken, []byte(`{"preferred_username":"test_oidc_user"}`))
  466. server.binding.OIDC.verifier = &mockOIDCVerifier{
  467. err: nil,
  468. token: idToken,
  469. }
  470. rr = httptest.NewRecorder()
  471. r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
  472. assert.NoError(t, err)
  473. server.router.ServeHTTP(rr, r)
  474. assert.Equal(t, http.StatusFound, rr.Code)
  475. assert.Equal(t, webClientFilesPath, rr.Header().Get("Location"))
  476. require.Len(t, oidcMgr.pendingAuths, 0)
  477. require.Len(t, oidcMgr.tokens, 1)
  478. // user profile is not available
  479. for k := range oidcMgr.tokens {
  480. tokenCookie = k
  481. }
  482. oidcToken, err = oidcMgr.getToken(tokenCookie)
  483. assert.NoError(t, err)
  484. assert.Empty(t, oidcToken.SessionID)
  485. assert.False(t, oidcToken.isAdmin())
  486. assert.False(t, oidcToken.isExpired())
  487. if assert.Len(t, oidcToken.Permissions, 1) {
  488. assert.Equal(t, sdk.WebClientSharesDisabled, oidcToken.Permissions[0])
  489. }
  490. rr = httptest.NewRecorder()
  491. r, err = http.NewRequest(http.MethodGet, webClientProfilePath, nil)
  492. assert.NoError(t, err)
  493. r.RequestURI = webClientProfilePath
  494. r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, tokenCookie))
  495. server.router.ServeHTTP(rr, r)
  496. assert.Equal(t, http.StatusOK, rr.Code)
  497. // the user can access the allowed pages
  498. rr = httptest.NewRecorder()
  499. r, err = http.NewRequest(http.MethodGet, webClientFilesPath, nil)
  500. assert.NoError(t, err)
  501. r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, tokenCookie))
  502. server.router.ServeHTTP(rr, r)
  503. assert.Equal(t, http.StatusOK, rr.Code)
  504. // try with an invalid cookie
  505. rr = httptest.NewRecorder()
  506. r, err = http.NewRequest(http.MethodGet, webClientFilesPath, nil)
  507. assert.NoError(t, err)
  508. r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, xid.New().String()))
  509. server.router.ServeHTTP(rr, r)
  510. assert.Equal(t, http.StatusFound, rr.Code)
  511. assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
  512. // Web Admin is not available with a client cookie
  513. rr = httptest.NewRecorder()
  514. r, err = http.NewRequest(http.MethodGet, webUsersPath, nil)
  515. assert.NoError(t, err)
  516. r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, tokenCookie))
  517. server.router.ServeHTTP(rr, r)
  518. assert.Equal(t, http.StatusFound, rr.Code)
  519. assert.Equal(t, webAdminLoginPath, rr.Header().Get("Location"))
  520. // logout the user
  521. rr = httptest.NewRecorder()
  522. r, err = http.NewRequest(http.MethodGet, webClientLogoutPath, nil)
  523. assert.NoError(t, err)
  524. r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, tokenCookie))
  525. server.router.ServeHTTP(rr, r)
  526. assert.Equal(t, http.StatusFound, rr.Code)
  527. assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
  528. require.Len(t, oidcMgr.pendingAuths, 0)
  529. require.Len(t, oidcMgr.tokens, 0)
  530. err = os.RemoveAll(user.GetHomeDir())
  531. assert.NoError(t, err)
  532. err = dataprovider.DeleteUser(username, "", "", "")
  533. assert.NoError(t, err)
  534. }
  535. func TestOIDCRefreshToken(t *testing.T) {
  536. oidcMgr, ok := oidcMgr.(*memoryOIDCManager)
  537. require.True(t, ok)
  538. r, err := http.NewRequest(http.MethodGet, webUsersPath, nil)
  539. assert.NoError(t, err)
  540. token := oidcToken{
  541. Cookie: xid.New().String(),
  542. AccessToken: xid.New().String(),
  543. TokenType: "Bearer",
  544. ExpiresAt: util.GetTimeAsMsSinceEpoch(time.Now().Add(-1 * time.Minute)),
  545. Nonce: xid.New().String(),
  546. Role: adminRoleFieldValue,
  547. Username: defaultAdminUsername,
  548. }
  549. config := mockOAuth2Config{
  550. tokenSource: &mockTokenSource{
  551. err: common.ErrGenericFailure,
  552. },
  553. }
  554. verifier := mockOIDCVerifier{
  555. err: common.ErrGenericFailure,
  556. }
  557. err = token.refresh(context.Background(), &config, &verifier, r)
  558. if assert.Error(t, err) {
  559. assert.Contains(t, err.Error(), "refresh token not set")
  560. }
  561. token.RefreshToken = xid.New().String()
  562. err = token.refresh(context.Background(), &config, &verifier, r)
  563. assert.ErrorIs(t, err, common.ErrGenericFailure)
  564. newToken := &oauth2.Token{
  565. AccessToken: xid.New().String(),
  566. RefreshToken: xid.New().String(),
  567. Expiry: time.Now().Add(5 * time.Minute),
  568. }
  569. config = mockOAuth2Config{
  570. tokenSource: &mockTokenSource{
  571. token: newToken,
  572. },
  573. }
  574. verifier = mockOIDCVerifier{
  575. token: &oidc.IDToken{},
  576. }
  577. err = token.refresh(context.Background(), &config, &verifier, r)
  578. if assert.Error(t, err) {
  579. assert.Contains(t, err.Error(), "the refreshed token has no id token")
  580. }
  581. newToken = newToken.WithExtra(map[string]any{
  582. "id_token": "id_token_val",
  583. })
  584. newToken.Expiry = time.Time{}
  585. config = mockOAuth2Config{
  586. tokenSource: &mockTokenSource{
  587. token: newToken,
  588. },
  589. }
  590. verifier = mockOIDCVerifier{
  591. err: common.ErrGenericFailure,
  592. }
  593. err = token.refresh(context.Background(), &config, &verifier, r)
  594. assert.ErrorIs(t, err, common.ErrGenericFailure)
  595. newToken = newToken.WithExtra(map[string]any{
  596. "id_token": "id_token_val",
  597. })
  598. newToken.Expiry = time.Now().Add(5 * time.Minute)
  599. config = mockOAuth2Config{
  600. tokenSource: &mockTokenSource{
  601. token: newToken,
  602. },
  603. }
  604. verifier = mockOIDCVerifier{
  605. token: &oidc.IDToken{},
  606. }
  607. err = token.refresh(context.Background(), &config, &verifier, r)
  608. if assert.Error(t, err) {
  609. assert.Contains(t, err.Error(), "the refreshed token nonce mismatch")
  610. }
  611. verifier = mockOIDCVerifier{
  612. token: &oidc.IDToken{
  613. Nonce: token.Nonce,
  614. },
  615. }
  616. err = token.refresh(context.Background(), &config, &verifier, r)
  617. if assert.Error(t, err) {
  618. assert.Contains(t, err.Error(), "oidc: claims not set")
  619. }
  620. idToken := &oidc.IDToken{
  621. Nonce: token.Nonce,
  622. }
  623. setIDTokenClaims(idToken, []byte(`{"sid":"id_token_sid"}`))
  624. verifier = mockOIDCVerifier{
  625. token: idToken,
  626. }
  627. err = token.refresh(context.Background(), &config, &verifier, r)
  628. assert.NoError(t, err)
  629. assert.Len(t, token.Permissions, 1)
  630. token.Role = nil
  631. // user does not exist
  632. err = token.refresh(context.Background(), &config, &verifier, r)
  633. assert.Error(t, err)
  634. require.Len(t, oidcMgr.tokens, 1)
  635. oidcMgr.removeToken(token.Cookie)
  636. require.Len(t, oidcMgr.tokens, 0)
  637. }
  638. func TestOIDCRefreshUser(t *testing.T) {
  639. token := oidcToken{
  640. Cookie: xid.New().String(),
  641. AccessToken: xid.New().String(),
  642. TokenType: "Bearer",
  643. ExpiresAt: util.GetTimeAsMsSinceEpoch(time.Now().Add(1 * time.Minute)),
  644. Nonce: xid.New().String(),
  645. Role: adminRoleFieldValue,
  646. Username: "missing username",
  647. }
  648. r, err := http.NewRequest(http.MethodGet, webUsersPath, nil)
  649. assert.NoError(t, err)
  650. err = token.refreshUser(r)
  651. assert.Error(t, err)
  652. admin := dataprovider.Admin{
  653. Username: "test_oidc_admin_refresh",
  654. Password: "p",
  655. Permissions: []string{dataprovider.PermAdminAny},
  656. Status: 0,
  657. Filters: dataprovider.AdminFilters{
  658. Preferences: dataprovider.AdminPreferences{
  659. HideUserPageSections: 1 + 2 + 4,
  660. },
  661. },
  662. }
  663. err = dataprovider.AddAdmin(&admin, "", "", "")
  664. assert.NoError(t, err)
  665. token.Username = admin.Username
  666. err = token.refreshUser(r)
  667. if assert.Error(t, err) {
  668. assert.Contains(t, err.Error(), "is disabled")
  669. }
  670. admin.Status = 1
  671. err = dataprovider.UpdateAdmin(&admin, "", "", "")
  672. assert.NoError(t, err)
  673. err = token.refreshUser(r)
  674. assert.NoError(t, err)
  675. assert.Equal(t, admin.Permissions, token.Permissions)
  676. assert.Equal(t, admin.Filters.Preferences.HideUserPageSections, token.HideUserPageSections)
  677. err = dataprovider.DeleteAdmin(admin.Username, "", "", "")
  678. assert.NoError(t, err)
  679. username := "test_oidc_user_refresh_token"
  680. user := dataprovider.User{
  681. BaseUser: sdk.BaseUser{
  682. Username: username,
  683. Password: "p",
  684. HomeDir: filepath.Join(os.TempDir(), username),
  685. Status: 0,
  686. Permissions: map[string][]string{
  687. "/": {dataprovider.PermAny},
  688. },
  689. },
  690. Filters: dataprovider.UserFilters{
  691. BaseUserFilters: sdk.BaseUserFilters{
  692. DeniedProtocols: []string{common.ProtocolHTTP},
  693. WebClient: []string{sdk.WebClientSharesDisabled, sdk.WebClientWriteDisabled},
  694. },
  695. },
  696. }
  697. err = dataprovider.AddUser(&user, "", "", "")
  698. assert.NoError(t, err)
  699. r, err = http.NewRequest(http.MethodGet, webClientFilesPath, nil)
  700. assert.NoError(t, err)
  701. token.Role = nil
  702. token.Username = username
  703. assert.False(t, token.isAdmin())
  704. err = token.refreshUser(r)
  705. if assert.Error(t, err) {
  706. assert.Contains(t, err.Error(), "is disabled")
  707. }
  708. user, err = dataprovider.UserExists(username, "")
  709. assert.NoError(t, err)
  710. user.Status = 1
  711. err = dataprovider.UpdateUser(&user, "", "", "")
  712. assert.NoError(t, err)
  713. err = token.refreshUser(r)
  714. if assert.Error(t, err) {
  715. assert.Contains(t, err.Error(), "protocol HTTP is not allowed")
  716. }
  717. user.Filters.DeniedProtocols = []string{common.ProtocolFTP}
  718. err = dataprovider.UpdateUser(&user, "", "", "")
  719. assert.NoError(t, err)
  720. err = token.refreshUser(r)
  721. assert.NoError(t, err)
  722. assert.Equal(t, user.Filters.WebClient, token.Permissions)
  723. err = dataprovider.DeleteUser(username, "", "", "")
  724. assert.NoError(t, err)
  725. }
  726. func TestValidateOIDCToken(t *testing.T) {
  727. oidcMgr, ok := oidcMgr.(*memoryOIDCManager)
  728. require.True(t, ok)
  729. server := getTestOIDCServer()
  730. err := server.binding.OIDC.initialize()
  731. assert.NoError(t, err)
  732. server.initializeRouter()
  733. rr := httptest.NewRecorder()
  734. r, err := http.NewRequest(http.MethodGet, webClientLogoutPath, nil)
  735. assert.NoError(t, err)
  736. _, err = server.validateOIDCToken(rr, r, false)
  737. assert.ErrorIs(t, err, errInvalidToken)
  738. // expired token and refresh error
  739. server.binding.OIDC.oauth2Config = &mockOAuth2Config{
  740. tokenSource: &mockTokenSource{
  741. err: common.ErrGenericFailure,
  742. },
  743. }
  744. token := oidcToken{
  745. Cookie: xid.New().String(),
  746. AccessToken: xid.New().String(),
  747. ExpiresAt: util.GetTimeAsMsSinceEpoch(time.Now().Add(-2 * time.Minute)),
  748. }
  749. oidcMgr.addToken(token)
  750. rr = httptest.NewRecorder()
  751. r, err = http.NewRequest(http.MethodGet, webClientLogoutPath, nil)
  752. assert.NoError(t, err)
  753. r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, token.Cookie))
  754. _, err = server.validateOIDCToken(rr, r, false)
  755. assert.ErrorIs(t, err, errInvalidToken)
  756. oidcMgr.removeToken(token.Cookie)
  757. assert.Len(t, oidcMgr.tokens, 0)
  758. server.tokenAuth = jwtauth.New("PS256", util.GenerateRandomBytes(32), nil)
  759. token = oidcToken{
  760. Cookie: xid.New().String(),
  761. AccessToken: xid.New().String(),
  762. }
  763. oidcMgr.addToken(token)
  764. rr = httptest.NewRecorder()
  765. r, err = http.NewRequest(http.MethodGet, webClientLogoutPath, nil)
  766. assert.NoError(t, err)
  767. r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, token.Cookie))
  768. server.router.ServeHTTP(rr, r)
  769. assert.Equal(t, http.StatusFound, rr.Code)
  770. assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
  771. oidcMgr.removeToken(token.Cookie)
  772. assert.Len(t, oidcMgr.tokens, 0)
  773. token = oidcToken{
  774. Cookie: xid.New().String(),
  775. AccessToken: xid.New().String(),
  776. Role: "admin",
  777. }
  778. oidcMgr.addToken(token)
  779. rr = httptest.NewRecorder()
  780. r, err = http.NewRequest(http.MethodGet, webLogoutPath, nil)
  781. assert.NoError(t, err)
  782. r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, token.Cookie))
  783. server.router.ServeHTTP(rr, r)
  784. assert.Equal(t, http.StatusFound, rr.Code)
  785. assert.Equal(t, webAdminLoginPath, rr.Header().Get("Location"))
  786. oidcMgr.removeToken(token.Cookie)
  787. assert.Len(t, oidcMgr.tokens, 0)
  788. }
  789. func TestSkipOIDCAuth(t *testing.T) {
  790. server := getTestOIDCServer()
  791. err := server.binding.OIDC.initialize()
  792. assert.NoError(t, err)
  793. server.initializeRouter()
  794. jwtTokenClaims := jwtTokenClaims{
  795. Username: "user",
  796. }
  797. _, tokenString, err := jwtTokenClaims.createToken(server.tokenAuth, tokenAudienceWebClient, "")
  798. assert.NoError(t, err)
  799. rr := httptest.NewRecorder()
  800. r, err := http.NewRequest(http.MethodGet, webClientLogoutPath, nil)
  801. assert.NoError(t, err)
  802. r.Header.Set("Cookie", fmt.Sprintf("%v=%v", jwtCookieKey, tokenString))
  803. server.router.ServeHTTP(rr, r)
  804. assert.Equal(t, http.StatusFound, rr.Code)
  805. assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
  806. }
  807. func TestOIDCLogoutErrors(t *testing.T) {
  808. server := getTestOIDCServer()
  809. assert.Empty(t, server.binding.OIDC.providerLogoutURL)
  810. server.logoutFromOIDCOP("")
  811. server.binding.OIDC.providerLogoutURL = "http://foo\x7f.com/"
  812. server.doOIDCFromLogout("")
  813. server.binding.OIDC.providerLogoutURL = "http://127.0.0.1:11234"
  814. server.doOIDCFromLogout("")
  815. }
  816. func TestOIDCToken(t *testing.T) {
  817. admin := dataprovider.Admin{
  818. Username: "test_oidc_admin",
  819. Password: "p",
  820. Permissions: []string{dataprovider.PermAdminAny},
  821. Status: 0,
  822. }
  823. err := dataprovider.AddAdmin(&admin, "", "", "")
  824. assert.NoError(t, err)
  825. token := oidcToken{
  826. Username: admin.Username,
  827. }
  828. // role not initialized, user with the specified username does not exist
  829. req, err := http.NewRequest(http.MethodGet, webUsersPath, nil)
  830. assert.NoError(t, err)
  831. err = token.getUser(req)
  832. assert.ErrorIs(t, err, util.ErrNotFound)
  833. token.Role = "admin"
  834. req, err = http.NewRequest(http.MethodGet, webUsersPath, nil)
  835. assert.NoError(t, err)
  836. err = token.getUser(req)
  837. if assert.Error(t, err) {
  838. assert.Contains(t, err.Error(), "is disabled")
  839. }
  840. err = dataprovider.DeleteAdmin(admin.Username, "", "", "")
  841. assert.NoError(t, err)
  842. username := "test_oidc_user"
  843. token.Username = username
  844. token.Role = ""
  845. err = token.getUser(req)
  846. if assert.Error(t, err) {
  847. assert.ErrorIs(t, err, util.ErrNotFound)
  848. }
  849. user := dataprovider.User{
  850. BaseUser: sdk.BaseUser{
  851. Username: username,
  852. Password: "p",
  853. HomeDir: filepath.Join(os.TempDir(), username),
  854. Status: 0,
  855. Permissions: map[string][]string{
  856. "/": {dataprovider.PermAny},
  857. },
  858. },
  859. Filters: dataprovider.UserFilters{
  860. BaseUserFilters: sdk.BaseUserFilters{
  861. DeniedProtocols: []string{common.ProtocolHTTP},
  862. },
  863. },
  864. }
  865. err = dataprovider.AddUser(&user, "", "", "")
  866. assert.NoError(t, err)
  867. err = token.getUser(req)
  868. if assert.Error(t, err) {
  869. assert.Contains(t, err.Error(), "is disabled")
  870. }
  871. user, err = dataprovider.UserExists(username, "")
  872. assert.NoError(t, err)
  873. user.Status = 1
  874. user.Password = "np"
  875. err = dataprovider.UpdateUser(&user, "", "", "")
  876. assert.NoError(t, err)
  877. err = token.getUser(req)
  878. if assert.Error(t, err) {
  879. assert.Contains(t, err.Error(), "protocol HTTP is not allowed")
  880. }
  881. user.Filters.DeniedProtocols = nil
  882. user.FsConfig.Provider = sdk.SFTPFilesystemProvider
  883. user.FsConfig.SFTPConfig = vfs.SFTPFsConfig{
  884. BaseSFTPFsConfig: sdk.BaseSFTPFsConfig{
  885. Endpoint: "127.0.0.1:8022",
  886. Username: username,
  887. },
  888. Password: kms.NewPlainSecret("np"),
  889. }
  890. err = dataprovider.UpdateUser(&user, "", "", "")
  891. assert.NoError(t, err)
  892. err = token.getUser(req)
  893. if assert.Error(t, err) {
  894. assert.Contains(t, err.Error(), "SFTP loop")
  895. }
  896. common.Config.PostConnectHook = fmt.Sprintf("http://%v/404", oidcMockAddr)
  897. err = token.getUser(req)
  898. if assert.Error(t, err) {
  899. assert.Contains(t, err.Error(), "access denied")
  900. }
  901. common.Config.PostConnectHook = ""
  902. err = os.RemoveAll(user.GetHomeDir())
  903. assert.NoError(t, err)
  904. err = dataprovider.DeleteUser(username, "", "", "")
  905. assert.NoError(t, err)
  906. }
  907. func TestOIDCImplicitRoles(t *testing.T) {
  908. oidcMgr, ok := oidcMgr.(*memoryOIDCManager)
  909. require.True(t, ok)
  910. server := getTestOIDCServer()
  911. server.binding.OIDC.ImplicitRoles = true
  912. err := server.binding.OIDC.initialize()
  913. assert.NoError(t, err)
  914. server.initializeRouter()
  915. authReq := newOIDCPendingAuth(tokenAudienceWebAdmin)
  916. oidcMgr.addPendingAuth(authReq)
  917. token := &oauth2.Token{
  918. AccessToken: "1234",
  919. Expiry: time.Now().Add(5 * time.Minute),
  920. }
  921. token = token.WithExtra(map[string]any{
  922. "id_token": "id_token_val",
  923. })
  924. server.binding.OIDC.oauth2Config = &mockOAuth2Config{
  925. tokenSource: &mockTokenSource{},
  926. authCodeURL: webOIDCRedirectPath,
  927. token: token,
  928. }
  929. idToken := &oidc.IDToken{
  930. Nonce: authReq.Nonce,
  931. Expiry: time.Now().Add(5 * time.Minute),
  932. }
  933. setIDTokenClaims(idToken, []byte(`{"preferred_username":"admin","sid":"sid456"}`))
  934. server.binding.OIDC.verifier = &mockOIDCVerifier{
  935. err: nil,
  936. token: idToken,
  937. }
  938. rr := httptest.NewRecorder()
  939. r, err := http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
  940. assert.NoError(t, err)
  941. server.router.ServeHTTP(rr, r)
  942. assert.Equal(t, http.StatusFound, rr.Code)
  943. assert.Equal(t, webUsersPath, rr.Header().Get("Location"))
  944. require.Len(t, oidcMgr.pendingAuths, 0)
  945. require.Len(t, oidcMgr.tokens, 1)
  946. var tokenCookie string
  947. for k := range oidcMgr.tokens {
  948. tokenCookie = k
  949. }
  950. // Web Client is not available with an admin token
  951. rr = httptest.NewRecorder()
  952. r, err = http.NewRequest(http.MethodGet, webClientFilesPath, nil)
  953. assert.NoError(t, err)
  954. r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, tokenCookie))
  955. server.router.ServeHTTP(rr, r)
  956. assert.Equal(t, http.StatusFound, rr.Code)
  957. assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
  958. // logout the admin user
  959. rr = httptest.NewRecorder()
  960. r, err = http.NewRequest(http.MethodGet, webLogoutPath, nil)
  961. assert.NoError(t, err)
  962. r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, tokenCookie))
  963. server.router.ServeHTTP(rr, r)
  964. assert.Equal(t, http.StatusFound, rr.Code)
  965. assert.Equal(t, webAdminLoginPath, rr.Header().Get("Location"))
  966. require.Len(t, oidcMgr.pendingAuths, 0)
  967. require.Len(t, oidcMgr.tokens, 0)
  968. // now login and logout a user
  969. username := "test_oidc_implicit_user"
  970. user := dataprovider.User{
  971. BaseUser: sdk.BaseUser{
  972. Username: username,
  973. Password: "pwd",
  974. HomeDir: filepath.Join(os.TempDir(), username),
  975. Status: 1,
  976. Permissions: map[string][]string{
  977. "/": {dataprovider.PermAny},
  978. },
  979. },
  980. Filters: dataprovider.UserFilters{
  981. BaseUserFilters: sdk.BaseUserFilters{
  982. WebClient: []string{sdk.WebClientSharesDisabled},
  983. },
  984. },
  985. }
  986. err = dataprovider.AddUser(&user, "", "", "")
  987. assert.NoError(t, err)
  988. authReq = newOIDCPendingAuth(tokenAudienceWebClient)
  989. oidcMgr.addPendingAuth(authReq)
  990. idToken = &oidc.IDToken{
  991. Nonce: authReq.Nonce,
  992. Expiry: time.Now().Add(5 * time.Minute),
  993. }
  994. setIDTokenClaims(idToken, []byte(`{"preferred_username":"test_oidc_implicit_user"}`))
  995. server.binding.OIDC.verifier = &mockOIDCVerifier{
  996. err: nil,
  997. token: idToken,
  998. }
  999. rr = httptest.NewRecorder()
  1000. r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
  1001. assert.NoError(t, err)
  1002. server.router.ServeHTTP(rr, r)
  1003. assert.Equal(t, http.StatusFound, rr.Code)
  1004. assert.Equal(t, webClientFilesPath, rr.Header().Get("Location"))
  1005. require.Len(t, oidcMgr.pendingAuths, 0)
  1006. require.Len(t, oidcMgr.tokens, 1)
  1007. for k := range oidcMgr.tokens {
  1008. tokenCookie = k
  1009. }
  1010. rr = httptest.NewRecorder()
  1011. r, err = http.NewRequest(http.MethodGet, webClientLogoutPath, nil)
  1012. assert.NoError(t, err)
  1013. r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, tokenCookie))
  1014. server.router.ServeHTTP(rr, r)
  1015. assert.Equal(t, http.StatusFound, rr.Code)
  1016. assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
  1017. require.Len(t, oidcMgr.pendingAuths, 0)
  1018. require.Len(t, oidcMgr.tokens, 0)
  1019. err = os.RemoveAll(user.GetHomeDir())
  1020. assert.NoError(t, err)
  1021. err = dataprovider.DeleteUser(username, "", "", "")
  1022. assert.NoError(t, err)
  1023. }
  1024. func TestMemoryOIDCManager(t *testing.T) {
  1025. oidcMgr, ok := oidcMgr.(*memoryOIDCManager)
  1026. require.True(t, ok)
  1027. require.Len(t, oidcMgr.pendingAuths, 0)
  1028. authReq := newOIDCPendingAuth(tokenAudienceWebAdmin)
  1029. oidcMgr.addPendingAuth(authReq)
  1030. require.Len(t, oidcMgr.pendingAuths, 1)
  1031. _, err := oidcMgr.getPendingAuth(authReq.State)
  1032. assert.NoError(t, err)
  1033. oidcMgr.removePendingAuth(authReq.State)
  1034. require.Len(t, oidcMgr.pendingAuths, 0)
  1035. authReq.IssuedAt = util.GetTimeAsMsSinceEpoch(time.Now().Add(-61 * time.Second))
  1036. oidcMgr.addPendingAuth(authReq)
  1037. require.Len(t, oidcMgr.pendingAuths, 1)
  1038. _, err = oidcMgr.getPendingAuth(authReq.State)
  1039. if assert.Error(t, err) {
  1040. assert.Contains(t, err.Error(), "too old")
  1041. }
  1042. oidcMgr.cleanup()
  1043. require.Len(t, oidcMgr.pendingAuths, 0)
  1044. token := oidcToken{
  1045. AccessToken: xid.New().String(),
  1046. Nonce: xid.New().String(),
  1047. SessionID: xid.New().String(),
  1048. Cookie: xid.New().String(),
  1049. Username: xid.New().String(),
  1050. Role: "admin",
  1051. Permissions: []string{dataprovider.PermAdminAny},
  1052. }
  1053. require.Len(t, oidcMgr.tokens, 0)
  1054. oidcMgr.addToken(token)
  1055. require.Len(t, oidcMgr.tokens, 1)
  1056. _, err = oidcMgr.getToken(xid.New().String())
  1057. assert.Error(t, err)
  1058. storedToken, err := oidcMgr.getToken(token.Cookie)
  1059. assert.NoError(t, err)
  1060. token.UsedAt = 0 // ensure we don't modify the stored token
  1061. assert.Greater(t, storedToken.UsedAt, int64(0))
  1062. token.UsedAt = storedToken.UsedAt
  1063. assert.Equal(t, token, storedToken)
  1064. // the usage will not be updated, it is recent
  1065. oidcMgr.updateTokenUsage(storedToken)
  1066. storedToken, err = oidcMgr.getToken(token.Cookie)
  1067. assert.NoError(t, err)
  1068. assert.Equal(t, token, storedToken)
  1069. usedAt := util.GetTimeAsMsSinceEpoch(time.Now().Add(-5 * time.Minute))
  1070. storedToken.UsedAt = usedAt
  1071. oidcMgr.tokens[token.Cookie] = storedToken
  1072. storedToken, err = oidcMgr.getToken(token.Cookie)
  1073. assert.NoError(t, err)
  1074. assert.Equal(t, usedAt, storedToken.UsedAt)
  1075. token.UsedAt = storedToken.UsedAt
  1076. assert.Equal(t, token, storedToken)
  1077. oidcMgr.updateTokenUsage(storedToken)
  1078. storedToken, err = oidcMgr.getToken(token.Cookie)
  1079. assert.NoError(t, err)
  1080. assert.Greater(t, storedToken.UsedAt, usedAt)
  1081. token.UsedAt = storedToken.UsedAt
  1082. assert.Equal(t, token, storedToken)
  1083. storedToken.UsedAt = util.GetTimeAsMsSinceEpoch(time.Now()) - tokenDeleteInterval - 1
  1084. oidcMgr.tokens[token.Cookie] = storedToken
  1085. storedToken, err = oidcMgr.getToken(token.Cookie)
  1086. if assert.Error(t, err) {
  1087. assert.Contains(t, err.Error(), "token is too old")
  1088. }
  1089. oidcMgr.removeToken(xid.New().String())
  1090. require.Len(t, oidcMgr.tokens, 1)
  1091. oidcMgr.removeToken(token.Cookie)
  1092. require.Len(t, oidcMgr.tokens, 0)
  1093. oidcMgr.addToken(token)
  1094. usedAt = util.GetTimeAsMsSinceEpoch(time.Now().Add(-6 * time.Hour))
  1095. token.UsedAt = usedAt
  1096. oidcMgr.tokens[token.Cookie] = token
  1097. newToken := oidcToken{
  1098. Cookie: xid.New().String(),
  1099. }
  1100. oidcMgr.addToken(newToken)
  1101. oidcMgr.cleanup()
  1102. require.Len(t, oidcMgr.tokens, 1)
  1103. _, err = oidcMgr.getToken(token.Cookie)
  1104. assert.Error(t, err)
  1105. _, err = oidcMgr.getToken(newToken.Cookie)
  1106. assert.NoError(t, err)
  1107. oidcMgr.removeToken(newToken.Cookie)
  1108. require.Len(t, oidcMgr.tokens, 0)
  1109. }
  1110. func TestOIDCEvMgrIntegration(t *testing.T) {
  1111. providerConf := dataprovider.GetProviderConfig()
  1112. err := dataprovider.Close()
  1113. assert.NoError(t, err)
  1114. newProviderConf := providerConf
  1115. newProviderConf.NamingRules = 5
  1116. err = dataprovider.Initialize(newProviderConf, configDir, true)
  1117. assert.NoError(t, err)
  1118. // add a special chars to check json replacer
  1119. username := `test_"oidc_eventmanager`
  1120. u := map[string]any{
  1121. "username": "{{Name}}",
  1122. "status": 1,
  1123. "home_dir": filepath.Join(os.TempDir(), "{{IDPFieldcustom1.sub}}"),
  1124. "permissions": map[string][]string{
  1125. "/": {dataprovider.PermAny},
  1126. },
  1127. "description": "{{IDPFieldcustom2}}",
  1128. }
  1129. userTmpl, err := json.Marshal(u)
  1130. require.NoError(t, err)
  1131. a := map[string]any{
  1132. "username": "{{Name}}",
  1133. "status": 1,
  1134. "permissions": []string{dataprovider.PermAdminAny},
  1135. }
  1136. adminTmpl, err := json.Marshal(a)
  1137. require.NoError(t, err)
  1138. action := &dataprovider.BaseEventAction{
  1139. Name: "a",
  1140. Type: dataprovider.ActionTypeIDPAccountCheck,
  1141. Options: dataprovider.BaseEventActionOptions{
  1142. IDPConfig: dataprovider.EventActionIDPAccountCheck{
  1143. Mode: 0,
  1144. TemplateUser: string(userTmpl),
  1145. TemplateAdmin: string(adminTmpl),
  1146. },
  1147. },
  1148. }
  1149. err = dataprovider.AddEventAction(action, "", "", "")
  1150. assert.NoError(t, err)
  1151. rule := &dataprovider.EventRule{
  1152. Name: "r",
  1153. Status: 1,
  1154. Trigger: dataprovider.EventTriggerIDPLogin,
  1155. Conditions: dataprovider.EventConditions{
  1156. IDPLoginEvent: 0,
  1157. },
  1158. Actions: []dataprovider.EventAction{
  1159. {
  1160. BaseEventAction: dataprovider.BaseEventAction{
  1161. Name: action.Name,
  1162. },
  1163. Options: dataprovider.EventActionOptions{
  1164. ExecuteSync: true,
  1165. },
  1166. },
  1167. },
  1168. }
  1169. err = dataprovider.AddEventRule(rule, "", "", "")
  1170. assert.NoError(t, err)
  1171. oidcMgr, ok := oidcMgr.(*memoryOIDCManager)
  1172. require.True(t, ok)
  1173. server := getTestOIDCServer()
  1174. server.binding.OIDC.ImplicitRoles = true
  1175. server.binding.OIDC.CustomFields = []string{"custom1.sub", "custom2"}
  1176. err = server.binding.OIDC.initialize()
  1177. assert.NoError(t, err)
  1178. server.initializeRouter()
  1179. // login a user with OIDC
  1180. _, err = dataprovider.UserExists(username, "")
  1181. assert.ErrorIs(t, err, util.ErrNotFound)
  1182. authReq := newOIDCPendingAuth(tokenAudienceWebClient)
  1183. oidcMgr.addPendingAuth(authReq)
  1184. token := &oauth2.Token{
  1185. AccessToken: "1234",
  1186. Expiry: time.Now().Add(5 * time.Minute),
  1187. }
  1188. token = token.WithExtra(map[string]any{
  1189. "id_token": "id_token_val",
  1190. })
  1191. server.binding.OIDC.oauth2Config = &mockOAuth2Config{
  1192. tokenSource: &mockTokenSource{},
  1193. authCodeURL: webOIDCRedirectPath,
  1194. token: token,
  1195. }
  1196. idToken := &oidc.IDToken{
  1197. Nonce: authReq.Nonce,
  1198. Expiry: time.Now().Add(5 * time.Minute),
  1199. }
  1200. setIDTokenClaims(idToken, []byte(`{"preferred_username":"`+util.JSONEscape(username)+`","custom1":{"sub":"val1"},"custom2":"desc"}`)) //nolint:goconst
  1201. server.binding.OIDC.verifier = &mockOIDCVerifier{
  1202. err: nil,
  1203. token: idToken,
  1204. }
  1205. rr := httptest.NewRecorder()
  1206. r, err := http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
  1207. assert.NoError(t, err)
  1208. server.router.ServeHTTP(rr, r)
  1209. assert.Equal(t, http.StatusFound, rr.Code)
  1210. assert.Equal(t, webClientFilesPath, rr.Header().Get("Location"))
  1211. user, err := dataprovider.UserExists(username, "")
  1212. assert.NoError(t, err)
  1213. assert.Equal(t, filepath.Join(os.TempDir(), "val1"), user.GetHomeDir())
  1214. assert.Equal(t, "desc", user.Description)
  1215. err = dataprovider.DeleteUser(username, "", "", "")
  1216. assert.NoError(t, err)
  1217. err = os.RemoveAll(user.GetHomeDir())
  1218. assert.NoError(t, err)
  1219. // login an admin with OIDC
  1220. _, err = dataprovider.AdminExists(username)
  1221. assert.ErrorIs(t, err, util.ErrNotFound)
  1222. authReq = newOIDCPendingAuth(tokenAudienceWebAdmin)
  1223. oidcMgr.addPendingAuth(authReq)
  1224. idToken = &oidc.IDToken{
  1225. Nonce: authReq.Nonce,
  1226. Expiry: time.Now().Add(5 * time.Minute),
  1227. }
  1228. setIDTokenClaims(idToken, []byte(`{"preferred_username":"`+util.JSONEscape(username)+`"}`))
  1229. server.binding.OIDC.verifier = &mockOIDCVerifier{
  1230. err: nil,
  1231. token: idToken,
  1232. }
  1233. rr = httptest.NewRecorder()
  1234. r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
  1235. assert.NoError(t, err)
  1236. server.router.ServeHTTP(rr, r)
  1237. assert.Equal(t, http.StatusFound, rr.Code)
  1238. assert.Equal(t, webUsersPath, rr.Header().Get("Location"))
  1239. _, err = dataprovider.AdminExists(username)
  1240. assert.NoError(t, err)
  1241. err = dataprovider.DeleteAdmin(username, "", "", "")
  1242. assert.NoError(t, err)
  1243. // set invalid templates and try again
  1244. action.Options.IDPConfig.TemplateUser = `{}`
  1245. action.Options.IDPConfig.TemplateAdmin = `{}`
  1246. err = dataprovider.UpdateEventAction(action, "", "", "")
  1247. assert.NoError(t, err)
  1248. for _, audience := range []string{tokenAudienceWebAdmin, tokenAudienceWebClient} {
  1249. authReq = newOIDCPendingAuth(audience)
  1250. oidcMgr.addPendingAuth(authReq)
  1251. idToken = &oidc.IDToken{
  1252. Nonce: authReq.Nonce,
  1253. Expiry: time.Now().Add(5 * time.Minute),
  1254. }
  1255. setIDTokenClaims(idToken, []byte(`{"preferred_username":"`+util.JSONEscape(username)+`"}`))
  1256. server.binding.OIDC.verifier = &mockOIDCVerifier{
  1257. err: nil,
  1258. token: idToken,
  1259. }
  1260. rr = httptest.NewRecorder()
  1261. r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
  1262. assert.NoError(t, err)
  1263. server.router.ServeHTTP(rr, r)
  1264. assert.Equal(t, http.StatusFound, rr.Code)
  1265. }
  1266. for k := range oidcMgr.tokens {
  1267. oidcMgr.removeToken(k)
  1268. }
  1269. err = dataprovider.DeleteEventRule(rule.Name, "", "", "")
  1270. assert.NoError(t, err)
  1271. err = dataprovider.DeleteEventAction(action.Name, "", "", "")
  1272. assert.NoError(t, err)
  1273. err = dataprovider.Close()
  1274. assert.NoError(t, err)
  1275. err = dataprovider.Initialize(providerConf, configDir, true)
  1276. assert.NoError(t, err)
  1277. }
  1278. func TestOIDCPreLoginHook(t *testing.T) {
  1279. if runtime.GOOS == osWindows {
  1280. t.Skip("this test is not available on Windows")
  1281. }
  1282. oidcMgr, ok := oidcMgr.(*memoryOIDCManager)
  1283. require.True(t, ok)
  1284. username := "test_oidc_user_prelogin"
  1285. u := dataprovider.User{
  1286. BaseUser: sdk.BaseUser{
  1287. Username: username,
  1288. HomeDir: filepath.Join(os.TempDir(), username),
  1289. Status: 1,
  1290. Permissions: map[string][]string{
  1291. "/": {dataprovider.PermAny},
  1292. },
  1293. },
  1294. }
  1295. preLoginPath := filepath.Join(os.TempDir(), "prelogin.sh")
  1296. providerConf := dataprovider.GetProviderConfig()
  1297. err := dataprovider.Close()
  1298. assert.NoError(t, err)
  1299. err = os.WriteFile(preLoginPath, getPreLoginScriptContent(u, false), os.ModePerm)
  1300. assert.NoError(t, err)
  1301. newProviderConf := providerConf
  1302. newProviderConf.PreLoginHook = preLoginPath
  1303. err = dataprovider.Initialize(newProviderConf, configDir, true)
  1304. assert.NoError(t, err)
  1305. server := getTestOIDCServer()
  1306. server.binding.OIDC.CustomFields = []string{"field1", "field2"}
  1307. err = server.binding.OIDC.initialize()
  1308. assert.NoError(t, err)
  1309. server.initializeRouter()
  1310. _, err = dataprovider.UserExists(username, "")
  1311. assert.ErrorIs(t, err, util.ErrNotFound)
  1312. // now login with OIDC
  1313. authReq := newOIDCPendingAuth(tokenAudienceWebClient)
  1314. oidcMgr.addPendingAuth(authReq)
  1315. token := &oauth2.Token{
  1316. AccessToken: "1234",
  1317. Expiry: time.Now().Add(5 * time.Minute),
  1318. }
  1319. token = token.WithExtra(map[string]any{
  1320. "id_token": "id_token_val",
  1321. })
  1322. server.binding.OIDC.oauth2Config = &mockOAuth2Config{
  1323. tokenSource: &mockTokenSource{},
  1324. authCodeURL: webOIDCRedirectPath,
  1325. token: token,
  1326. }
  1327. idToken := &oidc.IDToken{
  1328. Nonce: authReq.Nonce,
  1329. Expiry: time.Now().Add(5 * time.Minute),
  1330. }
  1331. setIDTokenClaims(idToken, []byte(`{"preferred_username":"`+username+`"}`))
  1332. server.binding.OIDC.verifier = &mockOIDCVerifier{
  1333. err: nil,
  1334. token: idToken,
  1335. }
  1336. rr := httptest.NewRecorder()
  1337. r, err := http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
  1338. assert.NoError(t, err)
  1339. server.router.ServeHTTP(rr, r)
  1340. assert.Equal(t, http.StatusFound, rr.Code)
  1341. assert.Equal(t, webClientFilesPath, rr.Header().Get("Location"))
  1342. _, err = dataprovider.UserExists(username, "")
  1343. assert.NoError(t, err)
  1344. err = dataprovider.DeleteUser(username, "", "", "")
  1345. assert.NoError(t, err)
  1346. err = os.RemoveAll(u.HomeDir)
  1347. assert.NoError(t, err)
  1348. err = os.WriteFile(preLoginPath, getPreLoginScriptContent(u, true), os.ModePerm)
  1349. assert.NoError(t, err)
  1350. authReq = newOIDCPendingAuth(tokenAudienceWebClient)
  1351. oidcMgr.addPendingAuth(authReq)
  1352. idToken = &oidc.IDToken{
  1353. Nonce: authReq.Nonce,
  1354. Expiry: time.Now().Add(5 * time.Minute),
  1355. }
  1356. setIDTokenClaims(idToken, []byte(`{"preferred_username":"`+username+`","field1":"value1","field2":"value2","field3":"value3"}`))
  1357. server.binding.OIDC.verifier = &mockOIDCVerifier{
  1358. err: nil,
  1359. token: idToken,
  1360. }
  1361. rr = httptest.NewRecorder()
  1362. r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
  1363. assert.NoError(t, err)
  1364. server.router.ServeHTTP(rr, r)
  1365. assert.Equal(t, http.StatusFound, rr.Code)
  1366. assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
  1367. _, err = dataprovider.UserExists(username, "")
  1368. assert.ErrorIs(t, err, util.ErrNotFound)
  1369. if assert.Len(t, oidcMgr.tokens, 1) {
  1370. for k := range oidcMgr.tokens {
  1371. oidcMgr.removeToken(k)
  1372. }
  1373. }
  1374. require.Len(t, oidcMgr.pendingAuths, 0)
  1375. require.Len(t, oidcMgr.tokens, 0)
  1376. err = dataprovider.Close()
  1377. assert.NoError(t, err)
  1378. err = dataprovider.Initialize(providerConf, configDir, true)
  1379. assert.NoError(t, err)
  1380. err = os.Remove(preLoginPath)
  1381. assert.NoError(t, err)
  1382. }
  1383. func TestOIDCIsAdmin(t *testing.T) {
  1384. type test struct {
  1385. input any
  1386. want bool
  1387. }
  1388. emptySlice := make([]any, 0)
  1389. tests := []test{
  1390. {input: "admin", want: true},
  1391. {input: append(emptySlice, "admin"), want: true},
  1392. {input: append(emptySlice, "user", "admin"), want: true},
  1393. {input: "user", want: false},
  1394. {input: emptySlice, want: false},
  1395. {input: append(emptySlice, 1), want: false},
  1396. {input: 1, want: false},
  1397. {input: nil, want: false},
  1398. {input: map[string]string{"admin": "admin"}, want: false},
  1399. }
  1400. for _, tc := range tests {
  1401. token := oidcToken{
  1402. Role: tc.input,
  1403. }
  1404. assert.Equal(t, tc.want, token.isAdmin(), "%v should return %t", tc.input, tc.want)
  1405. }
  1406. }
  1407. func TestParseAdminRole(t *testing.T) {
  1408. claims := make(map[string]any)
  1409. rawClaims := []byte(`{
  1410. "sub": "35666371",
  1411. "email": "example@example.com",
  1412. "preferred_username": "Sally",
  1413. "name": "Sally Tyler",
  1414. "updated_at": "2018-04-13T22:08:45Z",
  1415. "given_name": "Sally",
  1416. "family_name": "Tyler",
  1417. "params": {
  1418. "sftpgo_role": "admin",
  1419. "subparams": {
  1420. "sftpgo_role": "admin",
  1421. "inner": {
  1422. "sftpgo_role": ["user","admin"]
  1423. }
  1424. }
  1425. },
  1426. "at_hash": "lPLhxI2wjEndc-WfyroDZA",
  1427. "rt_hash": "mCmxPtA04N-55AxlEUbq-A",
  1428. "aud": "78d1d040-20c9-0136-5146-067351775fae92920",
  1429. "exp": 1523664997,
  1430. "iat": 1523657797
  1431. }`)
  1432. err := json.Unmarshal(rawClaims, &claims)
  1433. assert.NoError(t, err)
  1434. type test struct {
  1435. input string
  1436. want bool
  1437. val any
  1438. }
  1439. tests := []test{
  1440. {input: "", want: false},
  1441. {input: "sftpgo_role", want: false},
  1442. {input: "params.sftpgo_role", want: true, val: "admin"},
  1443. {input: "params.subparams.sftpgo_role", want: true, val: "admin"},
  1444. {input: "params.subparams.inner.sftpgo_role", want: true, val: []any{"user", "admin"}},
  1445. {input: "email", want: false},
  1446. {input: "missing", want: false},
  1447. {input: "params.email", want: false},
  1448. {input: "missing.sftpgo_role", want: false},
  1449. {input: "params", want: false},
  1450. {input: "params.subparams.inner.sftpgo_role.missing", want: false},
  1451. }
  1452. for _, tc := range tests {
  1453. token := oidcToken{}
  1454. token.getRoleFromField(claims, tc.input)
  1455. assert.Equal(t, tc.want, token.isAdmin(), "%q should return %t", tc.input, tc.want)
  1456. if tc.want {
  1457. assert.Equal(t, tc.val, token.Role)
  1458. }
  1459. }
  1460. }
  1461. func TestOIDCWithLoginFormsDisabled(t *testing.T) {
  1462. oidcMgr, ok := oidcMgr.(*memoryOIDCManager)
  1463. require.True(t, ok)
  1464. server := getTestOIDCServer()
  1465. server.binding.OIDC.ImplicitRoles = true
  1466. server.binding.EnabledLoginMethods = 3
  1467. server.binding.EnableWebAdmin = true
  1468. server.binding.EnableWebClient = true
  1469. err := server.binding.OIDC.initialize()
  1470. assert.NoError(t, err)
  1471. server.initializeRouter()
  1472. // login with an admin user
  1473. authReq := newOIDCPendingAuth(tokenAudienceWebAdmin)
  1474. oidcMgr.addPendingAuth(authReq)
  1475. token := &oauth2.Token{
  1476. AccessToken: "1234",
  1477. Expiry: time.Now().Add(5 * time.Minute),
  1478. }
  1479. token = token.WithExtra(map[string]any{
  1480. "id_token": "id_token_val",
  1481. })
  1482. server.binding.OIDC.oauth2Config = &mockOAuth2Config{
  1483. tokenSource: &mockTokenSource{},
  1484. authCodeURL: webOIDCRedirectPath,
  1485. token: token,
  1486. }
  1487. idToken := &oidc.IDToken{
  1488. Nonce: authReq.Nonce,
  1489. Expiry: time.Now().Add(5 * time.Minute),
  1490. }
  1491. setIDTokenClaims(idToken, []byte(`{"preferred_username":"admin","sid":"sid456"}`))
  1492. server.binding.OIDC.verifier = &mockOIDCVerifier{
  1493. err: nil,
  1494. token: idToken,
  1495. }
  1496. rr := httptest.NewRecorder()
  1497. r, err := http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
  1498. assert.NoError(t, err)
  1499. server.router.ServeHTTP(rr, r)
  1500. assert.Equal(t, http.StatusFound, rr.Code)
  1501. assert.Equal(t, webUsersPath, rr.Header().Get("Location"))
  1502. var tokenCookie string
  1503. for k := range oidcMgr.tokens {
  1504. tokenCookie = k
  1505. }
  1506. // we should be able to create admins without setting a password
  1507. if csrfTokenAuth == nil {
  1508. csrfTokenAuth = jwtauth.New(jwa.HS256.String(), util.GenerateRandomBytes(32), nil)
  1509. }
  1510. adminUsername := "testAdmin"
  1511. form := make(url.Values)
  1512. form.Set(csrfFormToken, createCSRFToken(""))
  1513. form.Set("username", adminUsername)
  1514. form.Set("password", "")
  1515. form.Set("status", "1")
  1516. form.Set("permissions", "*")
  1517. rr = httptest.NewRecorder()
  1518. r, err = http.NewRequest(http.MethodPost, webAdminPath, bytes.NewBuffer([]byte(form.Encode())))
  1519. assert.NoError(t, err)
  1520. r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, tokenCookie))
  1521. r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
  1522. server.router.ServeHTTP(rr, r)
  1523. assert.Equal(t, http.StatusSeeOther, rr.Code)
  1524. _, err = dataprovider.AdminExists(adminUsername)
  1525. assert.NoError(t, err)
  1526. err = dataprovider.DeleteAdmin(adminUsername, "", "", "")
  1527. assert.NoError(t, err)
  1528. // login and password related routes are disabled
  1529. rr = httptest.NewRecorder()
  1530. r, err = http.NewRequest(http.MethodPost, webAdminLoginPath, nil)
  1531. assert.NoError(t, err)
  1532. server.router.ServeHTTP(rr, r)
  1533. assert.Equal(t, http.StatusMethodNotAllowed, rr.Code)
  1534. rr = httptest.NewRecorder()
  1535. r, err = http.NewRequest(http.MethodPost, webAdminTwoFactorPath, nil)
  1536. assert.NoError(t, err)
  1537. server.router.ServeHTTP(rr, r)
  1538. assert.Equal(t, http.StatusNotFound, rr.Code)
  1539. rr = httptest.NewRecorder()
  1540. r, err = http.NewRequest(http.MethodPost, webClientLoginPath, nil)
  1541. assert.NoError(t, err)
  1542. server.router.ServeHTTP(rr, r)
  1543. assert.Equal(t, http.StatusMethodNotAllowed, rr.Code)
  1544. rr = httptest.NewRecorder()
  1545. r, err = http.NewRequest(http.MethodPost, webClientForgotPwdPath, nil)
  1546. assert.NoError(t, err)
  1547. server.router.ServeHTTP(rr, r)
  1548. assert.Equal(t, http.StatusNotFound, rr.Code)
  1549. }
  1550. func TestDbOIDCManager(t *testing.T) {
  1551. if !isSharedProviderSupported() {
  1552. t.Skip("this test it is not available with this provider")
  1553. }
  1554. mgr := newOIDCManager(1)
  1555. pendingAuth := newOIDCPendingAuth(tokenAudienceWebAdmin)
  1556. mgr.addPendingAuth(pendingAuth)
  1557. authReq, err := mgr.getPendingAuth(pendingAuth.State)
  1558. assert.NoError(t, err)
  1559. assert.Equal(t, pendingAuth, authReq)
  1560. pendingAuth.IssuedAt = util.GetTimeAsMsSinceEpoch(time.Now().Add(-24 * time.Hour))
  1561. mgr.addPendingAuth(pendingAuth)
  1562. _, err = mgr.getPendingAuth(pendingAuth.State)
  1563. if assert.Error(t, err) {
  1564. assert.Contains(t, err.Error(), "auth request is too old")
  1565. }
  1566. mgr.removePendingAuth(pendingAuth.State)
  1567. _, err = mgr.getPendingAuth(pendingAuth.State)
  1568. if assert.Error(t, err) {
  1569. assert.Contains(t, err.Error(), "unable to get the auth request for the specified state")
  1570. }
  1571. mgr.addPendingAuth(pendingAuth)
  1572. _, err = mgr.getPendingAuth(pendingAuth.State)
  1573. if assert.Error(t, err) {
  1574. assert.Contains(t, err.Error(), "auth request is too old")
  1575. }
  1576. mgr.cleanup()
  1577. _, err = mgr.getPendingAuth(pendingAuth.State)
  1578. if assert.Error(t, err) {
  1579. assert.Contains(t, err.Error(), "unable to get the auth request for the specified state")
  1580. }
  1581. token := oidcToken{
  1582. Cookie: xid.New().String(),
  1583. AccessToken: xid.New().String(),
  1584. TokenType: "Bearer",
  1585. RefreshToken: xid.New().String(),
  1586. ExpiresAt: util.GetTimeAsMsSinceEpoch(time.Now().Add(-2 * time.Minute)),
  1587. SessionID: xid.New().String(),
  1588. IDToken: xid.New().String(),
  1589. Nonce: xid.New().String(),
  1590. Username: xid.New().String(),
  1591. Permissions: []string{dataprovider.PermAdminAny},
  1592. Role: "admin",
  1593. }
  1594. mgr.addToken(token)
  1595. tokenGet, err := mgr.getToken(token.Cookie)
  1596. assert.NoError(t, err)
  1597. assert.Greater(t, tokenGet.UsedAt, int64(0))
  1598. token.UsedAt = tokenGet.UsedAt
  1599. assert.Equal(t, token, tokenGet)
  1600. time.Sleep(100 * time.Millisecond)
  1601. mgr.updateTokenUsage(token)
  1602. // no change
  1603. tokenGet, err = mgr.getToken(token.Cookie)
  1604. assert.NoError(t, err)
  1605. assert.Equal(t, token.UsedAt, tokenGet.UsedAt)
  1606. tokenGet.UsedAt = util.GetTimeAsMsSinceEpoch(time.Now().Add(-24 * time.Hour))
  1607. tokenGet.RefreshToken = xid.New().String()
  1608. mgr.updateTokenUsage(tokenGet)
  1609. tokenGet, err = mgr.getToken(token.Cookie)
  1610. assert.NoError(t, err)
  1611. assert.NotEmpty(t, tokenGet.RefreshToken)
  1612. assert.NotEqual(t, token.RefreshToken, tokenGet.RefreshToken)
  1613. assert.Greater(t, tokenGet.UsedAt, token.UsedAt)
  1614. mgr.removeToken(token.Cookie)
  1615. tokenGet, err = mgr.getToken(token.Cookie)
  1616. if assert.Error(t, err) {
  1617. assert.Contains(t, err.Error(), "unable to get the token for the specified session")
  1618. }
  1619. // add an expired token
  1620. token.UsedAt = util.GetTimeAsMsSinceEpoch(time.Now().Add(-24 * time.Hour))
  1621. session := dataprovider.Session{
  1622. Key: token.Cookie,
  1623. Data: token,
  1624. Type: dataprovider.SessionTypeOIDCToken,
  1625. Timestamp: token.UsedAt + tokenDeleteInterval,
  1626. }
  1627. err = dataprovider.AddSharedSession(session)
  1628. assert.NoError(t, err)
  1629. _, err = mgr.getToken(token.Cookie)
  1630. if assert.Error(t, err) {
  1631. assert.Contains(t, err.Error(), "token is too old")
  1632. }
  1633. mgr.cleanup()
  1634. _, err = mgr.getToken(token.Cookie)
  1635. if assert.Error(t, err) {
  1636. assert.Contains(t, err.Error(), "unable to get the token for the specified session")
  1637. }
  1638. // adding a session without a key should fail
  1639. session.Key = ""
  1640. err = dataprovider.AddSharedSession(session)
  1641. if assert.Error(t, err) {
  1642. assert.Contains(t, err.Error(), "unable to save a session with an empty key")
  1643. }
  1644. session.Key = xid.New().String()
  1645. session.Type = 1000
  1646. err = dataprovider.AddSharedSession(session)
  1647. if assert.Error(t, err) {
  1648. assert.Contains(t, err.Error(), "invalid session type")
  1649. }
  1650. dbMgr, ok := mgr.(*dbOIDCManager)
  1651. if assert.True(t, ok) {
  1652. _, err = dbMgr.decodePendingAuthData(2)
  1653. assert.Error(t, err)
  1654. _, err = dbMgr.decodeTokenData(true)
  1655. assert.Error(t, err)
  1656. }
  1657. }
  1658. func getTestOIDCServer() *httpdServer {
  1659. return &httpdServer{
  1660. binding: Binding{
  1661. OIDC: OIDC{
  1662. ClientID: "sftpgo-client",
  1663. ClientSecret: "jRsmE0SWnuZjP7djBqNq0mrf8QN77j2c",
  1664. ConfigURL: fmt.Sprintf("http://%v/auth/realms/sftpgo", oidcMockAddr),
  1665. RedirectBaseURL: "http://127.0.0.1:8081/",
  1666. UsernameField: "preferred_username",
  1667. RoleField: "sftpgo_role",
  1668. ImplicitRoles: false,
  1669. Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
  1670. CustomFields: nil,
  1671. Debug: true,
  1672. },
  1673. },
  1674. enableWebAdmin: true,
  1675. enableWebClient: true,
  1676. }
  1677. }
  1678. func getPreLoginScriptContent(user dataprovider.User, nonJSONResponse bool) []byte {
  1679. content := []byte("#!/bin/sh\n\n")
  1680. if nonJSONResponse {
  1681. content = append(content, []byte("echo 'text response'\n")...)
  1682. return content
  1683. }
  1684. if len(user.Username) > 0 {
  1685. u, _ := json.Marshal(user)
  1686. content = append(content, []byte(fmt.Sprintf("echo '%v'\n", string(u)))...)
  1687. }
  1688. return content
  1689. }