sqlcommon.go 102 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189319031913192319331943195319631973198319932003201320232033204320532063207320832093210321132123213321432153216321732183219322032213222322332243225322632273228322932303231323232333234323532363237323832393240324132423243324432453246324732483249325032513252325332543255325632573258325932603261326232633264326532663267326832693270327132723273327432753276327
732783279328032813282328332843285328632873288328932903291329232933294329532963297329832993300330133023303330433053306330733083309331033113312331333143315331633173318331933203321332233233324332533263327332833293330333133323333333433353336333733383339334033413342334333443345334633473348334933503351335233533354335533563357335833593360336133623363336433653366336733683369337033713372337333743375337633773378337933803381338233833384338533863387338833893390339133923393339433953396339733983399340034013402340334043405340634073408340934103411341234133414341534163417341834193420342134223423342434253426342734283429343034313432343334343435343634373438343934403441344234433444344534463447344834493450345134523453345434553456345734583459346034613462346334643465346634673468346934703471347234733474347534763477347834793480348134823483348434853486348734883489349034913492349334943495349634973498349935003501350235033504350535063507350835093510351135123513351435153516351735183519352035213522352335243525352635273528352935303531353235333534353535363537353835393540354135423543354435453546354735483549355035513552355335543555355635573558355935603561356235633564356535663567356835693570357135723573357435753576357735783579358035813582358335843585358635873588358935903591359235933594359535963597359835993600360136023603360436053606360736083609361036113612361336143615361636173618361936203621362236233624362536263627362836293630363136323633363436353636363736383639364036413642364336443645364636473648364936503651365236533654365536563657365836593660366136623663366436653666366736683669367036713672367336743675367636773678367936803681368236833684368536863687368836893690369136923693369436953696369736983699370037013702370337043705370637073708370937103711371237133714
  1. // Copyright (C) 2019-2022 Nicola Murino
  2. //
  3. // This program is free software: you can redistribute it and/or modify
  4. // it under the terms of the GNU Affero General Public License as published
  5. // by the Free Software Foundation, version 3.
  6. //
  7. // This program is distributed in the hope that it will be useful,
  8. // but WITHOUT ANY WARRANTY; without even the implied warranty of
  9. // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  10. // GNU Affero General Public License for more details.
  11. //
  12. // You should have received a copy of the GNU Affero General Public License
  13. // along with this program. If not, see <https://www.gnu.org/licenses/>.
  14. package dataprovider
  15. import (
  16. "context"
  17. "crypto/x509"
  18. "database/sql"
  19. "encoding/json"
  20. "errors"
  21. "fmt"
  22. "runtime/debug"
  23. "strings"
  24. "time"
  25. "github.com/cockroachdb/cockroach-go/v2/crdb"
  26. "github.com/sftpgo/sdk"
  27. "github.com/drakkan/sftpgo/v2/internal/logger"
  28. "github.com/drakkan/sftpgo/v2/internal/util"
  29. "github.com/drakkan/sftpgo/v2/internal/vfs"
  30. )
  // sqlDatabaseVersion is the SQL schema version number declared by this package.
const sqlDatabaseVersion = 25

// Timeouts applied to the contexts of SQL operations: defaultSQLQueryTimeout
// for ordinary reads and writes, longSQLQueryTimeout for slower operations
// such as dumping roles.
const (
	defaultSQLQueryTimeout = 10 * time.Second
	longSQLQueryTimeout    = time.Minute
)
  // Sentinel errors for failures to associate users, groups and virtual folders.
var (
	errSQLFoldersAssociation = errors.New("unable to associate virtual folders to user")
	errSQLGroupsAssociation  = errors.New("unable to associate groups to user")
	errSQLUsersAssociation   = errors.New("unable to associate users to group")
)

// errSchemaVersionEmpty signals that the schema_migration table has no rows,
// so the current schema version cannot be determined.
var errSchemaVersionEmpty = errors.New(`we can't determine schema version because the schema_migration table is empty. The SFTPGo database might be corrupted. Consider using the "resetprovider" sub-command`)
  // sqlQuerier is the set of context-aware query methods needed by the SQL
// helpers. Both *sql.DB and *sql.Tx satisfy it.
type sqlQuerier interface {
	ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
	PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
	QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
	QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
}

// sqlScanner is implemented by both *sql.Row and *sql.Rows, so a single
// decoding function can read a record from either.
type sqlScanner interface {
	Scan(dest ...any) error
}
  51. func sqlReplaceAll(sql string) string {
  52. sql = strings.ReplaceAll(sql, "{{schema_version}}", sqlTableSchemaVersion)
  53. sql = strings.ReplaceAll(sql, "{{admins}}", sqlTableAdmins)
  54. sql = strings.ReplaceAll(sql, "{{folders}}", sqlTableFolders)
  55. sql = strings.ReplaceAll(sql, "{{users}}", sqlTableUsers)
  56. sql = strings.ReplaceAll(sql, "{{groups}}", sqlTableGroups)
  57. sql = strings.ReplaceAll(sql, "{{users_folders_mapping}}", sqlTableUsersFoldersMapping)
  58. sql = strings.ReplaceAll(sql, "{{users_groups_mapping}}", sqlTableUsersGroupsMapping)
  59. sql = strings.ReplaceAll(sql, "{{admins_groups_mapping}}", sqlTableAdminsGroupsMapping)
  60. sql = strings.ReplaceAll(sql, "{{groups_folders_mapping}}", sqlTableGroupsFoldersMapping)
  61. sql = strings.ReplaceAll(sql, "{{api_keys}}", sqlTableAPIKeys)
  62. sql = strings.ReplaceAll(sql, "{{shares}}", sqlTableShares)
  63. sql = strings.ReplaceAll(sql, "{{defender_events}}", sqlTableDefenderEvents)
  64. sql = strings.ReplaceAll(sql, "{{defender_hosts}}", sqlTableDefenderHosts)
  65. sql = strings.ReplaceAll(sql, "{{active_transfers}}", sqlTableActiveTransfers)
  66. sql = strings.ReplaceAll(sql, "{{shared_sessions}}", sqlTableSharedSessions)
  67. sql = strings.ReplaceAll(sql, "{{events_actions}}", sqlTableEventsActions)
  68. sql = strings.ReplaceAll(sql, "{{events_rules}}", sqlTableEventsRules)
  69. sql = strings.ReplaceAll(sql, "{{rules_actions_mapping}}", sqlTableRulesActionsMapping)
  70. sql = strings.ReplaceAll(sql, "{{tasks}}", sqlTableTasks)
  71. sql = strings.ReplaceAll(sql, "{{nodes}}", sqlTableNodes)
  72. sql = strings.ReplaceAll(sql, "{{roles}}", sqlTableRoles)
  73. sql = strings.ReplaceAll(sql, "{{prefix}}", config.SQLTablesPrefix)
  74. return sql
  75. }
  76. func sqlCommonGetShareByID(shareID, username string, dbHandle sqlQuerier) (Share, error) {
  77. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  78. defer cancel()
  79. filterUser := username != ""
  80. q := getShareByIDQuery(filterUser)
  81. var row *sql.Row
  82. if filterUser {
  83. row = dbHandle.QueryRowContext(ctx, q, shareID, username)
  84. } else {
  85. row = dbHandle.QueryRowContext(ctx, q, shareID)
  86. }
  87. return getShareFromDbRow(row)
  88. }
  89. func sqlCommonAddShare(share *Share, dbHandle *sql.DB) error {
  90. err := share.validate()
  91. if err != nil {
  92. return err
  93. }
  94. user, err := provider.userExists(share.Username, "")
  95. if err != nil {
  96. return util.NewValidationError(fmt.Sprintf("unable to validate user %#v", share.Username))
  97. }
  98. paths, err := json.Marshal(share.Paths)
  99. if err != nil {
  100. return err
  101. }
  102. allowFrom := ""
  103. if len(share.AllowFrom) > 0 {
  104. res, err := json.Marshal(share.AllowFrom)
  105. if err == nil {
  106. allowFrom = string(res)
  107. }
  108. }
  109. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  110. defer cancel()
  111. q := getAddShareQuery()
  112. usedTokens := 0
  113. createdAt := util.GetTimeAsMsSinceEpoch(time.Now())
  114. updatedAt := createdAt
  115. lastUseAt := int64(0)
  116. if share.IsRestore {
  117. usedTokens = share.UsedTokens
  118. if share.CreatedAt > 0 {
  119. createdAt = share.CreatedAt
  120. }
  121. if share.UpdatedAt > 0 {
  122. updatedAt = share.UpdatedAt
  123. }
  124. lastUseAt = share.LastUseAt
  125. }
  126. _, err = dbHandle.ExecContext(ctx, q, share.ShareID, share.Name, share.Description, share.Scope,
  127. string(paths), createdAt, updatedAt, lastUseAt, share.ExpiresAt, share.Password,
  128. share.MaxTokens, usedTokens, allowFrom, user.ID)
  129. return err
  130. }
  131. func sqlCommonUpdateShare(share *Share, dbHandle *sql.DB) error {
  132. err := share.validate()
  133. if err != nil {
  134. return err
  135. }
  136. paths, err := json.Marshal(share.Paths)
  137. if err != nil {
  138. return err
  139. }
  140. allowFrom := ""
  141. if len(share.AllowFrom) > 0 {
  142. res, err := json.Marshal(share.AllowFrom)
  143. if err == nil {
  144. allowFrom = string(res)
  145. }
  146. }
  147. user, err := provider.userExists(share.Username, "")
  148. if err != nil {
  149. return util.NewValidationError(fmt.Sprintf("unable to validate user %#v", share.Username))
  150. }
  151. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  152. defer cancel()
  153. var q string
  154. if share.IsRestore {
  155. q = getUpdateShareRestoreQuery()
  156. } else {
  157. q = getUpdateShareQuery()
  158. }
  159. if share.IsRestore {
  160. if share.CreatedAt == 0 {
  161. share.CreatedAt = util.GetTimeAsMsSinceEpoch(time.Now())
  162. }
  163. if share.UpdatedAt == 0 {
  164. share.UpdatedAt = share.CreatedAt
  165. }
  166. _, err = dbHandle.ExecContext(ctx, q, share.Name, share.Description, share.Scope, string(paths),
  167. share.CreatedAt, share.UpdatedAt, share.LastUseAt, share.ExpiresAt, share.Password, share.MaxTokens,
  168. share.UsedTokens, allowFrom, user.ID, share.ShareID)
  169. } else {
  170. _, err = dbHandle.ExecContext(ctx, q, share.Name, share.Description, share.Scope, string(paths),
  171. util.GetTimeAsMsSinceEpoch(time.Now()), share.ExpiresAt, share.Password, share.MaxTokens,
  172. allowFrom, user.ID, share.ShareID)
  173. }
  174. return err
  175. }
  176. func sqlCommonDeleteShare(share Share, dbHandle *sql.DB) error {
  177. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  178. defer cancel()
  179. q := getDeleteShareQuery()
  180. res, err := dbHandle.ExecContext(ctx, q, share.ShareID)
  181. if err != nil {
  182. return err
  183. }
  184. return sqlCommonRequireRowAffected(res)
  185. }
  186. func sqlCommonGetShares(limit, offset int, order, username string, dbHandle sqlQuerier) ([]Share, error) {
  187. shares := make([]Share, 0, limit)
  188. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  189. defer cancel()
  190. q := getSharesQuery(order)
  191. rows, err := dbHandle.QueryContext(ctx, q, username, limit, offset)
  192. if err != nil {
  193. return shares, err
  194. }
  195. defer rows.Close()
  196. for rows.Next() {
  197. s, err := getShareFromDbRow(rows)
  198. if err != nil {
  199. return shares, err
  200. }
  201. s.HideConfidentialData()
  202. shares = append(shares, s)
  203. }
  204. return shares, rows.Err()
  205. }
  206. func sqlCommonDumpShares(dbHandle sqlQuerier) ([]Share, error) {
  207. shares := make([]Share, 0, 30)
  208. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  209. defer cancel()
  210. q := getDumpSharesQuery()
  211. rows, err := dbHandle.QueryContext(ctx, q)
  212. if err != nil {
  213. return shares, err
  214. }
  215. defer rows.Close()
  216. for rows.Next() {
  217. s, err := getShareFromDbRow(rows)
  218. if err != nil {
  219. return shares, err
  220. }
  221. shares = append(shares, s)
  222. }
  223. return shares, rows.Err()
  224. }
  225. func sqlCommonGetAPIKeyByID(keyID string, dbHandle sqlQuerier) (APIKey, error) {
  226. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  227. defer cancel()
  228. q := getAPIKeyByIDQuery()
  229. row := dbHandle.QueryRowContext(ctx, q, keyID)
  230. apiKey, err := getAPIKeyFromDbRow(row)
  231. if err != nil {
  232. return apiKey, err
  233. }
  234. return getAPIKeyWithRelatedFields(ctx, apiKey, dbHandle)
  235. }
  236. func sqlCommonAddAPIKey(apiKey *APIKey, dbHandle *sql.DB) error {
  237. err := apiKey.validate()
  238. if err != nil {
  239. return err
  240. }
  241. userID, adminID, err := sqlCommonGetAPIKeyRelatedIDs(apiKey)
  242. if err != nil {
  243. return err
  244. }
  245. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  246. defer cancel()
  247. q := getAddAPIKeyQuery()
  248. _, err = dbHandle.ExecContext(ctx, q, apiKey.KeyID, apiKey.Name, apiKey.Key, apiKey.Scope,
  249. util.GetTimeAsMsSinceEpoch(time.Now()), util.GetTimeAsMsSinceEpoch(time.Now()), apiKey.LastUseAt,
  250. apiKey.ExpiresAt, apiKey.Description, userID, adminID)
  251. return err
  252. }
  253. func sqlCommonUpdateAPIKey(apiKey *APIKey, dbHandle *sql.DB) error {
  254. err := apiKey.validate()
  255. if err != nil {
  256. return err
  257. }
  258. userID, adminID, err := sqlCommonGetAPIKeyRelatedIDs(apiKey)
  259. if err != nil {
  260. return err
  261. }
  262. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  263. defer cancel()
  264. q := getUpdateAPIKeyQuery()
  265. _, err = dbHandle.ExecContext(ctx, q, apiKey.Name, apiKey.Scope, apiKey.ExpiresAt, userID, adminID,
  266. apiKey.Description, util.GetTimeAsMsSinceEpoch(time.Now()), apiKey.KeyID)
  267. return err
  268. }
  269. func sqlCommonDeleteAPIKey(apiKey APIKey, dbHandle *sql.DB) error {
  270. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  271. defer cancel()
  272. q := getDeleteAPIKeyQuery()
  273. res, err := dbHandle.ExecContext(ctx, q, apiKey.KeyID)
  274. if err != nil {
  275. return err
  276. }
  277. return sqlCommonRequireRowAffected(res)
  278. }
  279. func sqlCommonGetAPIKeys(limit, offset int, order string, dbHandle sqlQuerier) ([]APIKey, error) {
  280. apiKeys := make([]APIKey, 0, limit)
  281. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  282. defer cancel()
  283. q := getAPIKeysQuery(order)
  284. rows, err := dbHandle.QueryContext(ctx, q, limit, offset)
  285. if err != nil {
  286. return apiKeys, err
  287. }
  288. defer rows.Close()
  289. for rows.Next() {
  290. k, err := getAPIKeyFromDbRow(rows)
  291. if err != nil {
  292. return apiKeys, err
  293. }
  294. k.HideConfidentialData()
  295. apiKeys = append(apiKeys, k)
  296. }
  297. err = rows.Err()
  298. if err != nil {
  299. return apiKeys, err
  300. }
  301. apiKeys, err = getRelatedValuesForAPIKeys(ctx, apiKeys, dbHandle, APIKeyScopeAdmin)
  302. if err != nil {
  303. return apiKeys, err
  304. }
  305. return getRelatedValuesForAPIKeys(ctx, apiKeys, dbHandle, APIKeyScopeUser)
  306. }
  307. func sqlCommonDumpAPIKeys(dbHandle sqlQuerier) ([]APIKey, error) {
  308. apiKeys := make([]APIKey, 0, 30)
  309. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  310. defer cancel()
  311. q := getDumpAPIKeysQuery()
  312. rows, err := dbHandle.QueryContext(ctx, q)
  313. if err != nil {
  314. return apiKeys, err
  315. }
  316. defer rows.Close()
  317. for rows.Next() {
  318. k, err := getAPIKeyFromDbRow(rows)
  319. if err != nil {
  320. return apiKeys, err
  321. }
  322. apiKeys = append(apiKeys, k)
  323. }
  324. err = rows.Err()
  325. if err != nil {
  326. return apiKeys, err
  327. }
  328. apiKeys, err = getRelatedValuesForAPIKeys(ctx, apiKeys, dbHandle, APIKeyScopeAdmin)
  329. if err != nil {
  330. return apiKeys, err
  331. }
  332. return getRelatedValuesForAPIKeys(ctx, apiKeys, dbHandle, APIKeyScopeUser)
  333. }
  334. func sqlCommonGetAdminByUsername(username string, dbHandle sqlQuerier) (Admin, error) {
  335. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  336. defer cancel()
  337. q := getAdminByUsernameQuery()
  338. row := dbHandle.QueryRowContext(ctx, q, username)
  339. admin, err := getAdminFromDbRow(row)
  340. if err != nil {
  341. return admin, err
  342. }
  343. return getAdminWithGroups(ctx, admin, dbHandle)
  344. }
  345. func sqlCommonValidateAdminAndPass(username, password, ip string, dbHandle *sql.DB) (Admin, error) {
  346. admin, err := sqlCommonGetAdminByUsername(username, dbHandle)
  347. if err != nil {
  348. providerLog(logger.LevelWarn, "error authenticating admin %#v: %v", username, err)
  349. return admin, ErrInvalidCredentials
  350. }
  351. err = admin.checkUserAndPass(password, ip)
  352. return admin, err
  353. }
  354. func sqlCommonAddAdmin(admin *Admin, dbHandle *sql.DB) error {
  355. err := admin.validate()
  356. if err != nil {
  357. return err
  358. }
  359. perms, err := json.Marshal(admin.Permissions)
  360. if err != nil {
  361. return err
  362. }
  363. filters, err := json.Marshal(admin.Filters)
  364. if err != nil {
  365. return err
  366. }
  367. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  368. defer cancel()
  369. return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error {
  370. q := getAddAdminQuery(admin.Role)
  371. _, err = tx.ExecContext(ctx, q, admin.Username, admin.Password, admin.Status, admin.Email, string(perms),
  372. string(filters), admin.AdditionalInfo, admin.Description, util.GetTimeAsMsSinceEpoch(time.Now()),
  373. util.GetTimeAsMsSinceEpoch(time.Now()), admin.Role)
  374. if err != nil {
  375. return err
  376. }
  377. return generateAdminGroupMapping(ctx, admin, tx)
  378. })
  379. }
  380. func sqlCommonUpdateAdmin(admin *Admin, dbHandle *sql.DB) error {
  381. err := admin.validate()
  382. if err != nil {
  383. return err
  384. }
  385. perms, err := json.Marshal(admin.Permissions)
  386. if err != nil {
  387. return err
  388. }
  389. filters, err := json.Marshal(admin.Filters)
  390. if err != nil {
  391. return err
  392. }
  393. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  394. defer cancel()
  395. return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error {
  396. q := getUpdateAdminQuery(admin.Role)
  397. _, err = tx.ExecContext(ctx, q, admin.Password, admin.Status, admin.Email, string(perms), string(filters),
  398. admin.AdditionalInfo, admin.Description, util.GetTimeAsMsSinceEpoch(time.Now()), admin.Role, admin.Username)
  399. if err != nil {
  400. return err
  401. }
  402. return generateAdminGroupMapping(ctx, admin, tx)
  403. })
  404. }
  405. func sqlCommonDeleteAdmin(admin Admin, dbHandle *sql.DB) error {
  406. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  407. defer cancel()
  408. q := getDeleteAdminQuery()
  409. res, err := dbHandle.ExecContext(ctx, q, admin.Username)
  410. if err != nil {
  411. return err
  412. }
  413. return sqlCommonRequireRowAffected(res)
  414. }
  415. func sqlCommonGetAdmins(limit, offset int, order string, dbHandle sqlQuerier) ([]Admin, error) {
  416. admins := make([]Admin, 0, limit)
  417. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  418. defer cancel()
  419. q := getAdminsQuery(order)
  420. rows, err := dbHandle.QueryContext(ctx, q, limit, offset)
  421. if err != nil {
  422. return admins, err
  423. }
  424. defer rows.Close()
  425. for rows.Next() {
  426. a, err := getAdminFromDbRow(rows)
  427. if err != nil {
  428. return admins, err
  429. }
  430. a.HideConfidentialData()
  431. admins = append(admins, a)
  432. }
  433. err = rows.Err()
  434. if err != nil {
  435. return admins, err
  436. }
  437. return getAdminsWithGroups(ctx, admins, dbHandle)
  438. }
  439. func sqlCommonDumpAdmins(dbHandle sqlQuerier) ([]Admin, error) {
  440. admins := make([]Admin, 0, 30)
  441. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  442. defer cancel()
  443. q := getDumpAdminsQuery()
  444. rows, err := dbHandle.QueryContext(ctx, q)
  445. if err != nil {
  446. return admins, err
  447. }
  448. defer rows.Close()
  449. for rows.Next() {
  450. a, err := getAdminFromDbRow(rows)
  451. if err != nil {
  452. return admins, err
  453. }
  454. admins = append(admins, a)
  455. }
  456. err = rows.Err()
  457. if err != nil {
  458. return admins, err
  459. }
  460. return getAdminsWithGroups(ctx, admins, dbHandle)
  461. }
  462. func sqlCommonGetRoleByName(name string, dbHandle sqlQuerier) (Role, error) {
  463. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  464. defer cancel()
  465. q := getRoleByNameQuery()
  466. row := dbHandle.QueryRowContext(ctx, q, name)
  467. role, err := getRoleFromDbRow(row)
  468. if err != nil {
  469. return role, err
  470. }
  471. role, err = getRoleWithUsers(ctx, role, dbHandle)
  472. if err != nil {
  473. return role, err
  474. }
  475. return getRoleWithAdmins(ctx, role, dbHandle)
  476. }
  477. func sqlCommonDumpRoles(dbHandle sqlQuerier) ([]Role, error) {
  478. roles := make([]Role, 0, 10)
  479. ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout)
  480. defer cancel()
  481. q := getDumpRolesQuery()
  482. rows, err := dbHandle.QueryContext(ctx, q)
  483. if err != nil {
  484. return roles, err
  485. }
  486. defer rows.Close()
  487. for rows.Next() {
  488. role, err := getRoleFromDbRow(rows)
  489. if err != nil {
  490. return roles, err
  491. }
  492. roles = append(roles, role)
  493. }
  494. return roles, rows.Err()
  495. }
  496. func sqlCommonGetRoles(limit int, offset int, order string, minimal bool, dbHandle sqlQuerier) ([]Role, error) {
  497. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  498. defer cancel()
  499. q := getRolesQuery(order, minimal)
  500. roles := make([]Role, 0, limit)
  501. rows, err := dbHandle.QueryContext(ctx, q, limit, offset)
  502. if err != nil {
  503. return roles, err
  504. }
  505. defer rows.Close()
  506. for rows.Next() {
  507. var role Role
  508. if minimal {
  509. err = rows.Scan(&role.ID, &role.Name)
  510. } else {
  511. role, err = getRoleFromDbRow(rows)
  512. }
  513. if err != nil {
  514. return roles, err
  515. }
  516. roles = append(roles, role)
  517. }
  518. err = rows.Err()
  519. if err != nil {
  520. return roles, err
  521. }
  522. if minimal {
  523. return roles, nil
  524. }
  525. roles, err = getRolesWithUsers(ctx, roles, dbHandle)
  526. if err != nil {
  527. return roles, err
  528. }
  529. return getRolesWithAdmins(ctx, roles, dbHandle)
  530. }
  531. func sqlCommonAddRole(role *Role, dbHandle *sql.DB) error {
  532. if err := role.validate(); err != nil {
  533. return err
  534. }
  535. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  536. defer cancel()
  537. q := getAddRoleQuery()
  538. _, err := dbHandle.ExecContext(ctx, q, role.Name, role.Description, util.GetTimeAsMsSinceEpoch(time.Now()),
  539. util.GetTimeAsMsSinceEpoch(time.Now()))
  540. return err
  541. }
  542. func sqlCommonUpdateRole(role *Role, dbHandle *sql.DB) error {
  543. if err := role.validate(); err != nil {
  544. return err
  545. }
  546. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  547. defer cancel()
  548. q := getUpdateRoleQuery()
  549. _, err := dbHandle.ExecContext(ctx, q, role.Description, util.GetTimeAsMsSinceEpoch(time.Now()), role.Name)
  550. return err
  551. }
  552. func sqlCommonDeleteRole(role Role, dbHandle *sql.DB) error {
  553. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  554. defer cancel()
  555. q := getDeleteRoleQuery()
  556. res, err := dbHandle.ExecContext(ctx, q, role.Name)
  557. if err != nil {
  558. return err
  559. }
  560. return sqlCommonRequireRowAffected(res)
  561. }
  562. func sqlCommonGetGroupByName(name string, dbHandle sqlQuerier) (Group, error) {
  563. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  564. defer cancel()
  565. q := getGroupByNameQuery()
  566. row := dbHandle.QueryRowContext(ctx, q, name)
  567. group, err := getGroupFromDbRow(row)
  568. if err != nil {
  569. return group, err
  570. }
  571. group, err = getGroupWithVirtualFolders(ctx, group, dbHandle)
  572. if err != nil {
  573. return group, err
  574. }
  575. group, err = getGroupWithUsers(ctx, group, dbHandle)
  576. if err != nil {
  577. return group, err
  578. }
  579. return getGroupWithAdmins(ctx, group, dbHandle)
  580. }
  581. func sqlCommonDumpGroups(dbHandle sqlQuerier) ([]Group, error) {
  582. groups := make([]Group, 0, 50)
  583. ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout)
  584. defer cancel()
  585. q := getDumpGroupsQuery()
  586. rows, err := dbHandle.QueryContext(ctx, q)
  587. if err != nil {
  588. return groups, err
  589. }
  590. defer rows.Close()
  591. for rows.Next() {
  592. group, err := getGroupFromDbRow(rows)
  593. if err != nil {
  594. return groups, err
  595. }
  596. groups = append(groups, group)
  597. }
  598. err = rows.Err()
  599. if err != nil {
  600. return groups, err
  601. }
  602. return getGroupsWithVirtualFolders(ctx, groups, dbHandle)
  603. }
  604. func sqlCommonGetUsersInGroups(names []string, dbHandle sqlQuerier) ([]string, error) {
  605. if len(names) == 0 {
  606. return nil, nil
  607. }
  608. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  609. defer cancel()
  610. q := getUsersInGroupsQuery(len(names))
  611. args := make([]any, 0, len(names))
  612. for _, name := range names {
  613. args = append(args, name)
  614. }
  615. usernames := make([]string, 0, len(names))
  616. rows, err := dbHandle.QueryContext(ctx, q, args...)
  617. if err != nil {
  618. return nil, err
  619. }
  620. defer rows.Close()
  621. for rows.Next() {
  622. var username string
  623. err = rows.Scan(&username)
  624. if err != nil {
  625. return usernames, err
  626. }
  627. usernames = append(usernames, username)
  628. }
  629. return usernames, rows.Err()
  630. }
  631. func sqlCommonGetGroupsWithNames(names []string, dbHandle sqlQuerier) ([]Group, error) {
  632. if len(names) == 0 {
  633. return nil, nil
  634. }
  635. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  636. defer cancel()
  637. q := getGroupsWithNamesQuery(len(names))
  638. args := make([]any, 0, len(names))
  639. for _, name := range names {
  640. args = append(args, name)
  641. }
  642. groups := make([]Group, 0, len(names))
  643. rows, err := dbHandle.QueryContext(ctx, q, args...)
  644. if err != nil {
  645. return groups, err
  646. }
  647. defer rows.Close()
  648. for rows.Next() {
  649. group, err := getGroupFromDbRow(rows)
  650. if err != nil {
  651. return groups, err
  652. }
  653. groups = append(groups, group)
  654. }
  655. err = rows.Err()
  656. if err != nil {
  657. return groups, err
  658. }
  659. return getGroupsWithVirtualFolders(ctx, groups, dbHandle)
  660. }
  661. func sqlCommonGetGroups(limit int, offset int, order string, minimal bool, dbHandle sqlQuerier) ([]Group, error) {
  662. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  663. defer cancel()
  664. q := getGroupsQuery(order, minimal)
  665. groups := make([]Group, 0, limit)
  666. rows, err := dbHandle.QueryContext(ctx, q, limit, offset)
  667. if err != nil {
  668. return groups, err
  669. }
  670. defer rows.Close()
  671. for rows.Next() {
  672. var group Group
  673. if minimal {
  674. err = rows.Scan(&group.ID, &group.Name)
  675. } else {
  676. group, err = getGroupFromDbRow(rows)
  677. }
  678. if err != nil {
  679. return groups, err
  680. }
  681. groups = append(groups, group)
  682. }
  683. err = rows.Err()
  684. if err != nil {
  685. return groups, err
  686. }
  687. if minimal {
  688. return groups, nil
  689. }
  690. groups, err = getGroupsWithVirtualFolders(ctx, groups, dbHandle)
  691. if err != nil {
  692. return groups, err
  693. }
  694. groups, err = getGroupsWithUsers(ctx, groups, dbHandle)
  695. if err != nil {
  696. return groups, err
  697. }
  698. groups, err = getGroupsWithAdmins(ctx, groups, dbHandle)
  699. if err != nil {
  700. return groups, err
  701. }
  702. for idx := range groups {
  703. groups[idx].PrepareForRendering()
  704. }
  705. return groups, nil
  706. }
  707. func sqlCommonAddGroup(group *Group, dbHandle *sql.DB) error {
  708. if err := group.validate(); err != nil {
  709. return err
  710. }
  711. settings, err := json.Marshal(group.UserSettings)
  712. if err != nil {
  713. return err
  714. }
  715. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  716. defer cancel()
  717. return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error {
  718. q := getAddGroupQuery()
  719. _, err := tx.ExecContext(ctx, q, group.Name, group.Description, util.GetTimeAsMsSinceEpoch(time.Now()),
  720. util.GetTimeAsMsSinceEpoch(time.Now()), string(settings))
  721. if err != nil {
  722. return err
  723. }
  724. return generateGroupVirtualFoldersMapping(ctx, group, tx)
  725. })
  726. }
  727. func sqlCommonUpdateGroup(group *Group, dbHandle *sql.DB) error {
  728. if err := group.validate(); err != nil {
  729. return err
  730. }
  731. settings, err := json.Marshal(group.UserSettings)
  732. if err != nil {
  733. return err
  734. }
  735. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  736. defer cancel()
  737. return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error {
  738. q := getUpdateGroupQuery()
  739. _, err := tx.ExecContext(ctx, q, group.Description, settings, util.GetTimeAsMsSinceEpoch(time.Now()), group.Name)
  740. if err != nil {
  741. return err
  742. }
  743. return generateGroupVirtualFoldersMapping(ctx, group, tx)
  744. })
  745. }
  746. func sqlCommonDeleteGroup(group Group, dbHandle *sql.DB) error {
  747. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  748. defer cancel()
  749. q := getDeleteGroupQuery()
  750. res, err := dbHandle.ExecContext(ctx, q, group.Name)
  751. if err != nil {
  752. return err
  753. }
  754. return sqlCommonRequireRowAffected(res)
  755. }
  756. func sqlCommonGetUserByUsername(username, role string, dbHandle sqlQuerier) (User, error) {
  757. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  758. defer cancel()
  759. q := getUserByUsernameQuery(role)
  760. args := []any{username}
  761. if role != "" {
  762. args = append(args, role)
  763. }
  764. row := dbHandle.QueryRowContext(ctx, q, args...)
  765. user, err := getUserFromDbRow(row)
  766. if err != nil {
  767. return user, err
  768. }
  769. user, err = getUserWithVirtualFolders(ctx, user, dbHandle)
  770. if err != nil {
  771. return user, err
  772. }
  773. return getUserWithGroups(ctx, user, dbHandle)
  774. }
  775. func sqlCommonValidateUserAndPass(username, password, ip, protocol string, dbHandle *sql.DB) (User, error) {
  776. user, err := sqlCommonGetUserByUsername(username, "", dbHandle)
  777. if err != nil {
  778. providerLog(logger.LevelWarn, "error authenticating user %#v: %v", username, err)
  779. return user, err
  780. }
  781. return checkUserAndPass(&user, password, ip, protocol)
  782. }
  783. func sqlCommonValidateUserAndTLSCertificate(username, protocol string, tlsCert *x509.Certificate, dbHandle *sql.DB) (User, error) {
  784. var user User
  785. if tlsCert == nil {
  786. return user, errors.New("TLS certificate cannot be null or empty")
  787. }
  788. user, err := sqlCommonGetUserByUsername(username, "", dbHandle)
  789. if err != nil {
  790. providerLog(logger.LevelWarn, "error authenticating user %#v: %v", username, err)
  791. return user, err
  792. }
  793. return checkUserAndTLSCertificate(&user, protocol, tlsCert)
  794. }
  795. func sqlCommonValidateUserAndPubKey(username string, pubKey []byte, isSSHCert bool, dbHandle *sql.DB) (User, string, error) {
  796. var user User
  797. if len(pubKey) == 0 {
  798. return user, "", errors.New("credentials cannot be null or empty")
  799. }
  800. user, err := sqlCommonGetUserByUsername(username, "", dbHandle)
  801. if err != nil {
  802. providerLog(logger.LevelWarn, "error authenticating user %#v: %v", username, err)
  803. return user, "", err
  804. }
  805. return checkUserAndPubKey(&user, pubKey, isSSHCert)
  806. }
  807. func sqlCommonCheckAvailability(dbHandle *sql.DB) (err error) {
  808. defer func() {
  809. if r := recover(); r != nil {
  810. providerLog(logger.LevelError, "panic in check provider availability, stack trace: %v", string(debug.Stack()))
  811. err = errors.New("unable to check provider status")
  812. }
  813. }()
  814. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  815. defer cancel()
  816. err = dbHandle.PingContext(ctx)
  817. return
  818. }
  819. func sqlCommonUpdateTransferQuota(username string, uploadSize, downloadSize int64, reset bool, dbHandle *sql.DB) error {
  820. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  821. defer cancel()
  822. q := getUpdateTransferQuotaQuery(reset)
  823. _, err := dbHandle.ExecContext(ctx, q, uploadSize, downloadSize, util.GetTimeAsMsSinceEpoch(time.Now()), username)
  824. if err == nil {
  825. providerLog(logger.LevelDebug, "transfer quota updated for user %#v, ul increment: %v dl increment: %v is reset? %v",
  826. username, uploadSize, downloadSize, reset)
  827. } else {
  828. providerLog(logger.LevelError, "error updating quota for user %#v: %v", username, err)
  829. }
  830. return err
  831. }
  832. func sqlCommonUpdateQuota(username string, filesAdd int, sizeAdd int64, reset bool, dbHandle *sql.DB) error {
  833. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  834. defer cancel()
  835. q := getUpdateQuotaQuery(reset)
  836. _, err := dbHandle.ExecContext(ctx, q, sizeAdd, filesAdd, util.GetTimeAsMsSinceEpoch(time.Now()), username)
  837. if err == nil {
  838. providerLog(logger.LevelDebug, "quota updated for user %#v, files increment: %v size increment: %v is reset? %v",
  839. username, filesAdd, sizeAdd, reset)
  840. } else {
  841. providerLog(logger.LevelError, "error updating quota for user %#v: %v", username, err)
  842. }
  843. return err
  844. }
  845. func sqlCommonGetUsedQuota(username string, dbHandle *sql.DB) (int, int64, int64, int64, error) {
  846. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  847. defer cancel()
  848. q := getQuotaQuery()
  849. var usedFiles int
  850. var usedSize, usedUploadSize, usedDownloadSize int64
  851. err := dbHandle.QueryRowContext(ctx, q, username).Scan(&usedSize, &usedFiles, &usedUploadSize, &usedDownloadSize)
  852. if err != nil {
  853. providerLog(logger.LevelError, "error getting quota for user: %v, error: %v", username, err)
  854. return 0, 0, 0, 0, err
  855. }
  856. return usedFiles, usedSize, usedUploadSize, usedDownloadSize, err
  857. }
  858. func sqlCommonUpdateShareLastUse(shareID string, numTokens int, dbHandle *sql.DB) error {
  859. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  860. defer cancel()
  861. q := getUpdateShareLastUseQuery()
  862. _, err := dbHandle.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(time.Now()), numTokens, shareID)
  863. if err == nil {
  864. providerLog(logger.LevelDebug, "last use updated for shared object %#v", shareID)
  865. } else {
  866. providerLog(logger.LevelWarn, "error updating last use for shared object %#v: %v", shareID, err)
  867. }
  868. return err
  869. }
  870. func sqlCommonUpdateAPIKeyLastUse(keyID string, dbHandle *sql.DB) error {
  871. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  872. defer cancel()
  873. q := getUpdateAPIKeyLastUseQuery()
  874. _, err := dbHandle.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(time.Now()), keyID)
  875. if err == nil {
  876. providerLog(logger.LevelDebug, "last use updated for key %#v", keyID)
  877. } else {
  878. providerLog(logger.LevelWarn, "error updating last use for key %#v: %v", keyID, err)
  879. }
  880. return err
  881. }
  882. func sqlCommonUpdateAdminLastLogin(username string, dbHandle *sql.DB) error {
  883. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  884. defer cancel()
  885. q := getUpdateAdminLastLoginQuery()
  886. _, err := dbHandle.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(time.Now()), username)
  887. if err == nil {
  888. providerLog(logger.LevelDebug, "last login updated for admin %#v", username)
  889. } else {
  890. providerLog(logger.LevelWarn, "error updating last login for admin %#v: %v", username, err)
  891. }
  892. return err
  893. }
  894. func sqlCommonSetUpdatedAt(username string, dbHandle *sql.DB) {
  895. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  896. defer cancel()
  897. q := getSetUpdateAtQuery()
  898. _, err := dbHandle.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(time.Now()), username)
  899. if err == nil {
  900. providerLog(logger.LevelDebug, "updated_at set for user %#v", username)
  901. } else {
  902. providerLog(logger.LevelWarn, "error setting updated_at for user %#v: %v", username, err)
  903. }
  904. }
  905. func sqlCommonSetFirstDownloadTimestamp(username string, dbHandle *sql.DB) error {
  906. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  907. defer cancel()
  908. q := getSetFirstDownloadQuery()
  909. res, err := dbHandle.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(time.Now()), username)
  910. if err != nil {
  911. return err
  912. }
  913. return sqlCommonRequireRowAffected(res)
  914. }
  915. func sqlCommonSetFirstUploadTimestamp(username string, dbHandle *sql.DB) error {
  916. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  917. defer cancel()
  918. q := getSetFirstUploadQuery()
  919. res, err := dbHandle.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(time.Now()), username)
  920. if err != nil {
  921. return err
  922. }
  923. return sqlCommonRequireRowAffected(res)
  924. }
  925. func sqlCommonUpdateLastLogin(username string, dbHandle *sql.DB) error {
  926. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  927. defer cancel()
  928. q := getUpdateLastLoginQuery()
  929. _, err := dbHandle.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(time.Now()), username)
  930. if err == nil {
  931. providerLog(logger.LevelDebug, "last login updated for user %#v", username)
  932. } else {
  933. providerLog(logger.LevelWarn, "error updating last login for user %#v: %v", username, err)
  934. }
  935. return err
  936. }
  937. func sqlCommonAddUser(user *User, dbHandle *sql.DB) error {
  938. err := ValidateUser(user)
  939. if err != nil {
  940. return err
  941. }
  942. permissions, err := user.GetPermissionsAsJSON()
  943. if err != nil {
  944. return err
  945. }
  946. publicKeys, err := user.GetPublicKeysAsJSON()
  947. if err != nil {
  948. return err
  949. }
  950. filters, err := user.GetFiltersAsJSON()
  951. if err != nil {
  952. return err
  953. }
  954. fsConfig, err := user.GetFsConfigAsJSON()
  955. if err != nil {
  956. return err
  957. }
  958. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  959. defer cancel()
  960. return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error {
  961. if config.IsShared == 1 {
  962. _, err := tx.ExecContext(ctx, getRemoveSoftDeletedUserQuery(), user.Username)
  963. if err != nil {
  964. return err
  965. }
  966. }
  967. q := getAddUserQuery(user.Role)
  968. _, err := tx.ExecContext(ctx, q, user.Username, user.Password, string(publicKeys), user.HomeDir, user.UID, user.GID,
  969. user.MaxSessions, user.QuotaSize, user.QuotaFiles, string(permissions), user.UploadBandwidth,
  970. user.DownloadBandwidth, user.Status, user.ExpirationDate, string(filters), string(fsConfig), user.AdditionalInfo,
  971. user.Description, user.Email, util.GetTimeAsMsSinceEpoch(time.Now()), util.GetTimeAsMsSinceEpoch(time.Now()),
  972. user.UploadDataTransfer, user.DownloadDataTransfer, user.TotalDataTransfer, user.Role, user.LastPasswordChange)
  973. if err != nil {
  974. return err
  975. }
  976. if err := generateUserVirtualFoldersMapping(ctx, user, tx); err != nil {
  977. return err
  978. }
  979. return generateUserGroupMapping(ctx, user, tx)
  980. })
  981. }
  982. func sqlCommonUpdateUserPassword(username, password string, dbHandle *sql.DB) error {
  983. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  984. defer cancel()
  985. q := getUpdateUserPasswordQuery()
  986. _, err := dbHandle.ExecContext(ctx, q, password, username)
  987. return err
  988. }
  989. func sqlCommonUpdateUser(user *User, dbHandle *sql.DB) error {
  990. err := ValidateUser(user)
  991. if err != nil {
  992. return err
  993. }
  994. permissions, err := user.GetPermissionsAsJSON()
  995. if err != nil {
  996. return err
  997. }
  998. publicKeys, err := user.GetPublicKeysAsJSON()
  999. if err != nil {
  1000. return err
  1001. }
  1002. filters, err := user.GetFiltersAsJSON()
  1003. if err != nil {
  1004. return err
  1005. }
  1006. fsConfig, err := user.GetFsConfigAsJSON()
  1007. if err != nil {
  1008. return err
  1009. }
  1010. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  1011. defer cancel()
  1012. return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error {
  1013. q := getUpdateUserQuery(user.Role)
  1014. _, err := tx.ExecContext(ctx, q, user.Password, string(publicKeys), user.HomeDir, user.UID, user.GID, user.MaxSessions,
  1015. user.QuotaSize, user.QuotaFiles, string(permissions), user.UploadBandwidth, user.DownloadBandwidth, user.Status,
  1016. user.ExpirationDate, string(filters), string(fsConfig), user.AdditionalInfo, user.Description, user.Email,
  1017. util.GetTimeAsMsSinceEpoch(time.Now()), user.UploadDataTransfer, user.DownloadDataTransfer, user.TotalDataTransfer,
  1018. user.Role, user.LastPasswordChange, user.ID)
  1019. if err != nil {
  1020. return err
  1021. }
  1022. if err := generateUserVirtualFoldersMapping(ctx, user, tx); err != nil {
  1023. return err
  1024. }
  1025. return generateUserGroupMapping(ctx, user, tx)
  1026. })
  1027. }
  1028. func sqlCommonDeleteUser(user User, softDelete bool, dbHandle *sql.DB) error {
  1029. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  1030. defer cancel()
  1031. q := getDeleteUserQuery(softDelete)
  1032. if softDelete {
  1033. return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error {
  1034. if err := sqlCommonClearUserFolderMapping(ctx, &user, tx); err != nil {
  1035. return err
  1036. }
  1037. if err := sqlCommonClearUserGroupMapping(ctx, &user, tx); err != nil {
  1038. return err
  1039. }
  1040. ts := util.GetTimeAsMsSinceEpoch(time.Now())
  1041. res, err := tx.ExecContext(ctx, q, ts, ts, user.Username)
  1042. if err != nil {
  1043. return err
  1044. }
  1045. return sqlCommonRequireRowAffected(res)
  1046. })
  1047. }
  1048. res, err := dbHandle.ExecContext(ctx, q, user.ID)
  1049. if err != nil {
  1050. return err
  1051. }
  1052. return sqlCommonRequireRowAffected(res)
  1053. }
  1054. func sqlCommonDumpUsers(dbHandle sqlQuerier) ([]User, error) {
  1055. users := make([]User, 0, 100)
  1056. ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout)
  1057. defer cancel()
  1058. q := getDumpUsersQuery()
  1059. rows, err := dbHandle.QueryContext(ctx, q)
  1060. if err != nil {
  1061. return users, err
  1062. }
  1063. defer rows.Close()
  1064. for rows.Next() {
  1065. u, err := getUserFromDbRow(rows)
  1066. if err != nil {
  1067. return users, err
  1068. }
  1069. users = append(users, u)
  1070. }
  1071. err = rows.Err()
  1072. if err != nil {
  1073. return users, err
  1074. }
  1075. users, err = getUsersWithVirtualFolders(ctx, users, dbHandle)
  1076. if err != nil {
  1077. return users, err
  1078. }
  1079. return getUsersWithGroups(ctx, users, dbHandle)
  1080. }
  1081. func sqlCommonGetRecentlyUpdatedUsers(after int64, dbHandle sqlQuerier) ([]User, error) {
  1082. users := make([]User, 0, 10)
  1083. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  1084. defer cancel()
  1085. q := getRecentlyUpdatedUsersQuery()
  1086. rows, err := dbHandle.QueryContext(ctx, q, after)
  1087. if err != nil {
  1088. return users, err
  1089. }
  1090. defer rows.Close()
  1091. for rows.Next() {
  1092. u, err := getUserFromDbRow(rows)
  1093. if err != nil {
  1094. return users, err
  1095. }
  1096. users = append(users, u)
  1097. }
  1098. err = rows.Err()
  1099. if err != nil {
  1100. return users, err
  1101. }
  1102. users, err = getUsersWithVirtualFolders(ctx, users, dbHandle)
  1103. if err != nil {
  1104. return users, err
  1105. }
  1106. users, err = getUsersWithGroups(ctx, users, dbHandle)
  1107. if err != nil {
  1108. return users, err
  1109. }
  1110. var groupNames []string
  1111. for _, u := range users {
  1112. for _, g := range u.Groups {
  1113. groupNames = append(groupNames, g.Name)
  1114. }
  1115. }
  1116. groupNames = util.RemoveDuplicates(groupNames, false)
  1117. groups, err := sqlCommonGetGroupsWithNames(groupNames, dbHandle)
  1118. if err != nil {
  1119. return users, err
  1120. }
  1121. if len(groups) == 0 {
  1122. return users, nil
  1123. }
  1124. groupsMapping := make(map[string]Group)
  1125. for idx := range groups {
  1126. groupsMapping[groups[idx].Name] = groups[idx]
  1127. }
  1128. for idx := range users {
  1129. ref := &users[idx]
  1130. ref.applyGroupSettings(groupsMapping)
  1131. }
  1132. return users, nil
  1133. }
  1134. func sqlCommonGetUsersForQuotaCheck(toFetch map[string]bool, dbHandle sqlQuerier) ([]User, error) {
  1135. users := make([]User, 0, 30)
  1136. usernames := make([]string, 0, len(toFetch))
  1137. for k := range toFetch {
  1138. usernames = append(usernames, k)
  1139. }
  1140. maxUsers := 30
  1141. for len(usernames) > 0 {
  1142. if maxUsers > len(usernames) {
  1143. maxUsers = len(usernames)
  1144. }
  1145. usersRange, err := sqlCommonGetUsersRangeForQuotaCheck(usernames[:maxUsers], dbHandle)
  1146. if err != nil {
  1147. return users, err
  1148. }
  1149. users = append(users, usersRange...)
  1150. usernames = usernames[maxUsers:]
  1151. }
  1152. var usersWithFolders []User
  1153. validIdx := 0
  1154. for _, user := range users {
  1155. if toFetch[user.Username] {
  1156. usersWithFolders = append(usersWithFolders, user)
  1157. } else {
  1158. users[validIdx] = user
  1159. validIdx++
  1160. }
  1161. }
  1162. users = users[:validIdx]
  1163. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  1164. defer cancel()
  1165. usersWithFolders, err := getUsersWithVirtualFolders(ctx, usersWithFolders, dbHandle)
  1166. if err != nil {
  1167. return users, err
  1168. }
  1169. users = append(users, usersWithFolders...)
  1170. users, err = getUsersWithGroups(ctx, users, dbHandle)
  1171. if err != nil {
  1172. return users, err
  1173. }
  1174. var groupNames []string
  1175. for _, u := range users {
  1176. for _, g := range u.Groups {
  1177. groupNames = append(groupNames, g.Name)
  1178. }
  1179. }
  1180. groupNames = util.RemoveDuplicates(groupNames, false)
  1181. if len(groupNames) == 0 {
  1182. return users, nil
  1183. }
  1184. groups, err := sqlCommonGetGroupsWithNames(groupNames, dbHandle)
  1185. if err != nil {
  1186. return users, err
  1187. }
  1188. groupsMapping := make(map[string]Group)
  1189. for idx := range groups {
  1190. groupsMapping[groups[idx].Name] = groups[idx]
  1191. }
  1192. for idx := range users {
  1193. ref := &users[idx]
  1194. ref.applyGroupSettings(groupsMapping)
  1195. }
  1196. return users, nil
  1197. }
  1198. func sqlCommonGetUsersRangeForQuotaCheck(usernames []string, dbHandle sqlQuerier) ([]User, error) {
  1199. users := make([]User, 0, len(usernames))
  1200. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  1201. defer cancel()
  1202. q := getUsersForQuotaCheckQuery(len(usernames))
  1203. queryArgs := make([]any, 0, len(usernames))
  1204. for idx := range usernames {
  1205. queryArgs = append(queryArgs, usernames[idx])
  1206. }
  1207. rows, err := dbHandle.QueryContext(ctx, q, queryArgs...)
  1208. if err != nil {
  1209. return nil, err
  1210. }
  1211. defer rows.Close()
  1212. for rows.Next() {
  1213. var user User
  1214. var filters []byte
  1215. err = rows.Scan(&user.ID, &user.Username, &user.QuotaSize, &user.UsedQuotaSize, &user.TotalDataTransfer,
  1216. &user.UploadDataTransfer, &user.DownloadDataTransfer, &user.UsedUploadDataTransfer,
  1217. &user.UsedDownloadDataTransfer, &filters)
  1218. if err != nil {
  1219. return users, err
  1220. }
  1221. var userFilters UserFilters
  1222. err = json.Unmarshal(filters, &userFilters)
  1223. if err == nil {
  1224. user.Filters = userFilters
  1225. }
  1226. users = append(users, user)
  1227. }
  1228. return users, rows.Err()
  1229. }
  1230. func sqlCommonAddActiveTransfer(transfer ActiveTransfer, dbHandle *sql.DB) error {
  1231. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  1232. defer cancel()
  1233. q := getAddActiveTransferQuery()
  1234. now := util.GetTimeAsMsSinceEpoch(time.Now())
  1235. _, err := dbHandle.ExecContext(ctx, q, transfer.ID, transfer.ConnID, transfer.Type, transfer.Username,
  1236. transfer.FolderName, transfer.IP, transfer.TruncatedSize, transfer.CurrentULSize, transfer.CurrentDLSize,
  1237. now, now)
  1238. return err
  1239. }
  1240. func sqlCommonUpdateActiveTransferSizes(ulSize, dlSize, transferID int64, connectionID string, dbHandle *sql.DB) error {
  1241. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  1242. defer cancel()
  1243. q := getUpdateActiveTransferSizesQuery()
  1244. _, err := dbHandle.ExecContext(ctx, q, ulSize, dlSize, util.GetTimeAsMsSinceEpoch(time.Now()), connectionID, transferID)
  1245. return err
  1246. }
  1247. func sqlCommonRemoveActiveTransfer(transferID int64, connectionID string, dbHandle *sql.DB) error {
  1248. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  1249. defer cancel()
  1250. q := getRemoveActiveTransferQuery()
  1251. _, err := dbHandle.ExecContext(ctx, q, connectionID, transferID)
  1252. return err
  1253. }
  1254. func sqlCommonCleanupActiveTransfers(before time.Time, dbHandle *sql.DB) error {
  1255. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  1256. defer cancel()
  1257. q := getCleanupActiveTransfersQuery()
  1258. _, err := dbHandle.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(before))
  1259. return err
  1260. }
  1261. func sqlCommonGetActiveTransfers(from time.Time, dbHandle sqlQuerier) ([]ActiveTransfer, error) {
  1262. transfers := make([]ActiveTransfer, 0, 30)
  1263. ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout)
  1264. defer cancel()
  1265. q := getActiveTransfersQuery()
  1266. rows, err := dbHandle.QueryContext(ctx, q, util.GetTimeAsMsSinceEpoch(from))
  1267. if err != nil {
  1268. return nil, err
  1269. }
  1270. defer rows.Close()
  1271. for rows.Next() {
  1272. var transfer ActiveTransfer
  1273. var folderName sql.NullString
  1274. err = rows.Scan(&transfer.ID, &transfer.ConnID, &transfer.Type, &transfer.Username, &folderName, &transfer.IP,
  1275. &transfer.TruncatedSize, &transfer.CurrentULSize, &transfer.CurrentDLSize, &transfer.CreatedAt,
  1276. &transfer.UpdatedAt)
  1277. if err != nil {
  1278. return transfers, err
  1279. }
  1280. if folderName.Valid {
  1281. transfer.FolderName = folderName.String
  1282. }
  1283. transfers = append(transfers, transfer)
  1284. }
  1285. return transfers, rows.Err()
  1286. }
  1287. func sqlCommonGetUsers(limit int, offset int, order, role string, dbHandle sqlQuerier) ([]User, error) {
  1288. users := make([]User, 0, limit)
  1289. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  1290. defer cancel()
  1291. q := getUsersQuery(order, role)
  1292. var args []any
  1293. if role == "" {
  1294. args = append(args, limit, offset)
  1295. } else {
  1296. args = append(args, role, limit, offset)
  1297. }
  1298. rows, err := dbHandle.QueryContext(ctx, q, args...)
  1299. if err != nil {
  1300. return users, err
  1301. }
  1302. defer rows.Close()
  1303. for rows.Next() {
  1304. u, err := getUserFromDbRow(rows)
  1305. if err != nil {
  1306. return users, err
  1307. }
  1308. users = append(users, u)
  1309. }
  1310. err = rows.Err()
  1311. if err != nil {
  1312. return users, err
  1313. }
  1314. users, err = getUsersWithVirtualFolders(ctx, users, dbHandle)
  1315. if err != nil {
  1316. return users, err
  1317. }
  1318. users, err = getUsersWithGroups(ctx, users, dbHandle)
  1319. if err != nil {
  1320. return users, err
  1321. }
  1322. for idx := range users {
  1323. users[idx].PrepareForRendering()
  1324. }
  1325. return users, nil
  1326. }
  1327. func sqlCommonGetDefenderHosts(from int64, limit int, dbHandle sqlQuerier) ([]DefenderEntry, error) {
  1328. hosts := make([]DefenderEntry, 0, 100)
  1329. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  1330. defer cancel()
  1331. q := getDefenderHostsQuery()
  1332. rows, err := dbHandle.QueryContext(ctx, q, from, limit)
  1333. if err != nil {
  1334. providerLog(logger.LevelError, "unable to get defender hosts: %v", err)
  1335. return hosts, err
  1336. }
  1337. defer rows.Close()
  1338. var idForScores []int64
  1339. for rows.Next() {
  1340. var banTime sql.NullInt64
  1341. host := DefenderEntry{}
  1342. err = rows.Scan(&host.ID, &host.IP, &banTime)
  1343. if err != nil {
  1344. providerLog(logger.LevelError, "unable to scan defender host row: %v", err)
  1345. return hosts, err
  1346. }
  1347. var hostBanTime time.Time
  1348. if banTime.Valid && banTime.Int64 > 0 {
  1349. hostBanTime = util.GetTimeFromMsecSinceEpoch(banTime.Int64)
  1350. }
  1351. if hostBanTime.IsZero() || hostBanTime.Before(time.Now()) {
  1352. idForScores = append(idForScores, host.ID)
  1353. } else {
  1354. host.BanTime = hostBanTime
  1355. }
  1356. hosts = append(hosts, host)
  1357. }
  1358. err = rows.Err()
  1359. if err != nil {
  1360. providerLog(logger.LevelError, "unable to iterate over defender host rows: %v", err)
  1361. return hosts, err
  1362. }
  1363. return getDefenderHostsWithScores(ctx, hosts, from, idForScores, dbHandle)
  1364. }
  1365. func sqlCommonIsDefenderHostBanned(ip string, dbHandle sqlQuerier) (DefenderEntry, error) {
  1366. var host DefenderEntry
  1367. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  1368. defer cancel()
  1369. q := getDefenderIsHostBannedQuery()
  1370. row := dbHandle.QueryRowContext(ctx, q, ip, util.GetTimeAsMsSinceEpoch(time.Now()))
  1371. err := row.Scan(&host.ID)
  1372. if err != nil {
  1373. if errors.Is(err, sql.ErrNoRows) {
  1374. return host, util.NewRecordNotFoundError("host not found")
  1375. }
  1376. providerLog(logger.LevelError, "unable to check ban status for host %#v: %v", ip, err)
  1377. return host, err
  1378. }
  1379. return host, nil
  1380. }
  1381. func sqlCommonGetDefenderHostByIP(ip string, from int64, dbHandle sqlQuerier) (DefenderEntry, error) {
  1382. var host DefenderEntry
  1383. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  1384. defer cancel()
  1385. q := getDefenderHostQuery()
  1386. row := dbHandle.QueryRowContext(ctx, q, ip, from)
  1387. var banTime sql.NullInt64
  1388. err := row.Scan(&host.ID, &host.IP, &banTime)
  1389. if err != nil {
  1390. if errors.Is(err, sql.ErrNoRows) {
  1391. return host, util.NewRecordNotFoundError("host not found")
  1392. }
  1393. providerLog(logger.LevelError, "unable to get host for ip %#v: %v", ip, err)
  1394. return host, err
  1395. }
  1396. if banTime.Valid && banTime.Int64 > 0 {
  1397. hostBanTime := util.GetTimeFromMsecSinceEpoch(banTime.Int64)
  1398. if !hostBanTime.IsZero() && hostBanTime.After(time.Now()) {
  1399. host.BanTime = hostBanTime
  1400. return host, nil
  1401. }
  1402. }
  1403. hosts, err := getDefenderHostsWithScores(ctx, []DefenderEntry{host}, from, []int64{host.ID}, dbHandle)
  1404. if err != nil {
  1405. return host, err
  1406. }
  1407. if len(hosts) == 0 {
  1408. return host, util.NewRecordNotFoundError("host not found")
  1409. }
  1410. return hosts[0], nil
  1411. }
  1412. func sqlCommonDefenderIncrementBanTime(ip string, minutesToAdd int, dbHandle *sql.DB) error {
  1413. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  1414. defer cancel()
  1415. q := getDefenderIncrementBanTimeQuery()
  1416. _, err := dbHandle.ExecContext(ctx, q, minutesToAdd*60000, ip)
  1417. if err == nil {
  1418. providerLog(logger.LevelDebug, "ban time updated for ip %#v, increment (minutes): %v",
  1419. ip, minutesToAdd)
  1420. } else {
  1421. providerLog(logger.LevelError, "error updating ban time for ip %#v: %v", ip, err)
  1422. }
  1423. return err
  1424. }
  1425. func sqlCommonSetDefenderBanTime(ip string, banTime int64, dbHandle *sql.DB) error {
  1426. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  1427. defer cancel()
  1428. q := getDefenderSetBanTimeQuery()
  1429. _, err := dbHandle.ExecContext(ctx, q, banTime, ip)
  1430. if err == nil {
  1431. providerLog(logger.LevelDebug, "ip %#v banned until %v", ip, util.GetTimeFromMsecSinceEpoch(banTime))
  1432. } else {
  1433. providerLog(logger.LevelError, "error setting ban time for ip %#v: %v", ip, err)
  1434. }
  1435. return err
  1436. }
  1437. func sqlCommonDeleteDefenderHost(ip string, dbHandle sqlQuerier) error {
  1438. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  1439. defer cancel()
  1440. q := getDeleteDefenderHostQuery()
  1441. res, err := dbHandle.ExecContext(ctx, q, ip)
  1442. if err != nil {
  1443. providerLog(logger.LevelError, "unable to delete defender host %#v: %v", ip, err)
  1444. return err
  1445. }
  1446. return sqlCommonRequireRowAffected(res)
  1447. }
  1448. func sqlCommonAddDefenderHostAndEvent(ip string, score int, dbHandle *sql.DB) error {
  1449. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  1450. defer cancel()
  1451. return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error {
  1452. if err := sqlCommonAddDefenderHost(ctx, ip, tx); err != nil {
  1453. return err
  1454. }
  1455. return sqlCommonAddDefenderEvent(ctx, ip, score, tx)
  1456. })
  1457. }
  1458. func sqlCommonDefenderCleanup(from int64, dbHandler *sql.DB) error {
  1459. if err := sqlCommonCleanupDefenderEvents(from, dbHandler); err != nil {
  1460. return err
  1461. }
  1462. return sqlCommonCleanupDefenderHosts(from, dbHandler)
  1463. }
  1464. func sqlCommonAddDefenderHost(ctx context.Context, ip string, tx *sql.Tx) error {
  1465. q := getAddDefenderHostQuery()
  1466. _, err := tx.ExecContext(ctx, q, ip, util.GetTimeAsMsSinceEpoch(time.Now()))
  1467. if err != nil {
  1468. providerLog(logger.LevelError, "unable to add defender host %#v: %v", ip, err)
  1469. }
  1470. return err
  1471. }
  1472. func sqlCommonAddDefenderEvent(ctx context.Context, ip string, score int, tx *sql.Tx) error {
  1473. q := getAddDefenderEventQuery()
  1474. _, err := tx.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(time.Now()), score, ip)
  1475. if err != nil {
  1476. providerLog(logger.LevelError, "unable to add defender event for %#v: %v", ip, err)
  1477. }
  1478. return err
  1479. }
  1480. func sqlCommonCleanupDefenderHosts(from int64, dbHandle *sql.DB) error {
  1481. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  1482. defer cancel()
  1483. q := getDefenderHostsCleanupQuery()
  1484. _, err := dbHandle.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(time.Now()), from)
  1485. if err != nil {
  1486. providerLog(logger.LevelError, "unable to cleanup defender hosts: %v", err)
  1487. }
  1488. return err
  1489. }
  1490. func sqlCommonCleanupDefenderEvents(from int64, dbHandle *sql.DB) error {
  1491. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  1492. defer cancel()
  1493. q := getDefenderEventsCleanupQuery()
  1494. _, err := dbHandle.ExecContext(ctx, q, from)
  1495. if err != nil {
  1496. providerLog(logger.LevelError, "unable to cleanup defender events: %v", err)
  1497. }
  1498. return err
  1499. }
  1500. func getShareFromDbRow(row sqlScanner) (Share, error) {
  1501. var share Share
  1502. var description, password sql.NullString
  1503. var allowFrom, paths []byte
  1504. err := row.Scan(&share.ShareID, &share.Name, &description, &share.Scope,
  1505. &paths, &share.Username, &share.CreatedAt, &share.UpdatedAt,
  1506. &share.LastUseAt, &share.ExpiresAt, &password, &share.MaxTokens,
  1507. &share.UsedTokens, &allowFrom)
  1508. if err != nil {
  1509. if errors.Is(err, sql.ErrNoRows) {
  1510. return share, util.NewRecordNotFoundError(err.Error())
  1511. }
  1512. return share, err
  1513. }
  1514. var list []string
  1515. err = json.Unmarshal(paths, &list)
  1516. if err != nil {
  1517. return share, err
  1518. }
  1519. share.Paths = list
  1520. if description.Valid {
  1521. share.Description = description.String
  1522. }
  1523. if password.Valid {
  1524. share.Password = password.String
  1525. }
  1526. list = nil
  1527. err = json.Unmarshal(allowFrom, &list)
  1528. if err == nil {
  1529. share.AllowFrom = list
  1530. }
  1531. return share, nil
  1532. }
  1533. func getAPIKeyFromDbRow(row sqlScanner) (APIKey, error) {
  1534. var apiKey APIKey
  1535. var userID, adminID sql.NullInt64
  1536. var description sql.NullString
  1537. err := row.Scan(&apiKey.KeyID, &apiKey.Name, &apiKey.Key, &apiKey.Scope, &apiKey.CreatedAt, &apiKey.UpdatedAt,
  1538. &apiKey.LastUseAt, &apiKey.ExpiresAt, &description, &userID, &adminID)
  1539. if err != nil {
  1540. if errors.Is(err, sql.ErrNoRows) {
  1541. return apiKey, util.NewRecordNotFoundError(err.Error())
  1542. }
  1543. return apiKey, err
  1544. }
  1545. if userID.Valid {
  1546. apiKey.userID = userID.Int64
  1547. }
  1548. if adminID.Valid {
  1549. apiKey.adminID = adminID.Int64
  1550. }
  1551. if description.Valid {
  1552. apiKey.Description = description.String
  1553. }
  1554. return apiKey, nil
  1555. }
  1556. func getAdminFromDbRow(row sqlScanner) (Admin, error) {
  1557. var admin Admin
  1558. var email, additionalInfo, description, role sql.NullString
  1559. var permissions, filters []byte
  1560. err := row.Scan(&admin.ID, &admin.Username, &admin.Password, &admin.Status, &email, &permissions,
  1561. &filters, &additionalInfo, &description, &admin.CreatedAt, &admin.UpdatedAt, &admin.LastLogin, &role)
  1562. if err != nil {
  1563. if errors.Is(err, sql.ErrNoRows) {
  1564. return admin, util.NewRecordNotFoundError(err.Error())
  1565. }
  1566. return admin, err
  1567. }
  1568. var perms []string
  1569. err = json.Unmarshal(permissions, &perms)
  1570. if err != nil {
  1571. return admin, err
  1572. }
  1573. admin.Permissions = perms
  1574. if email.Valid {
  1575. admin.Email = email.String
  1576. }
  1577. var adminFilters AdminFilters
  1578. err = json.Unmarshal(filters, &adminFilters)
  1579. if err == nil {
  1580. admin.Filters = adminFilters
  1581. }
  1582. if additionalInfo.Valid {
  1583. admin.AdditionalInfo = additionalInfo.String
  1584. }
  1585. if description.Valid {
  1586. admin.Description = description.String
  1587. }
  1588. if role.Valid {
  1589. admin.Role = role.String
  1590. }
  1591. admin.SetEmptySecretsIfNil()
  1592. return admin, nil
  1593. }
  1594. func getEventActionFromDbRow(row sqlScanner) (BaseEventAction, error) {
  1595. var action BaseEventAction
  1596. var description sql.NullString
  1597. var options []byte
  1598. err := row.Scan(&action.ID, &action.Name, &description, &action.Type, &options)
  1599. if err != nil {
  1600. if errors.Is(err, sql.ErrNoRows) {
  1601. return action, util.NewRecordNotFoundError(err.Error())
  1602. }
  1603. return action, err
  1604. }
  1605. if description.Valid {
  1606. action.Description = description.String
  1607. }
  1608. var actionOptions BaseEventActionOptions
  1609. err = json.Unmarshal(options, &actionOptions)
  1610. if err == nil {
  1611. action.Options = actionOptions
  1612. }
  1613. return action, nil
  1614. }
  1615. func getEventRuleFromDbRow(row sqlScanner) (EventRule, error) {
  1616. var rule EventRule
  1617. var description sql.NullString
  1618. var conditions []byte
  1619. err := row.Scan(&rule.ID, &rule.Name, &description, &rule.CreatedAt, &rule.UpdatedAt, &rule.Trigger,
  1620. &conditions, &rule.DeletedAt)
  1621. if err != nil {
  1622. if errors.Is(err, sql.ErrNoRows) {
  1623. return rule, util.NewRecordNotFoundError(err.Error())
  1624. }
  1625. return rule, err
  1626. }
  1627. var ruleConditions EventConditions
  1628. err = json.Unmarshal(conditions, &ruleConditions)
  1629. if err == nil {
  1630. rule.Conditions = ruleConditions
  1631. }
  1632. if description.Valid {
  1633. rule.Description = description.String
  1634. }
  1635. return rule, nil
  1636. }
  1637. func getRoleFromDbRow(row sqlScanner) (Role, error) {
  1638. var role Role
  1639. var description sql.NullString
  1640. err := row.Scan(&role.ID, &role.Name, &description, &role.CreatedAt, &role.UpdatedAt)
  1641. if err != nil {
  1642. if errors.Is(err, sql.ErrNoRows) {
  1643. return role, util.NewRecordNotFoundError(err.Error())
  1644. }
  1645. return role, err
  1646. }
  1647. if description.Valid {
  1648. role.Description = description.String
  1649. }
  1650. return role, nil
  1651. }
  1652. func getGroupFromDbRow(row sqlScanner) (Group, error) {
  1653. var group Group
  1654. var description sql.NullString
  1655. var userSettings []byte
  1656. err := row.Scan(&group.ID, &group.Name, &description, &group.CreatedAt, &group.UpdatedAt, &userSettings)
  1657. if err != nil {
  1658. if errors.Is(err, sql.ErrNoRows) {
  1659. return group, util.NewRecordNotFoundError(err.Error())
  1660. }
  1661. return group, err
  1662. }
  1663. if description.Valid {
  1664. group.Description = description.String
  1665. }
  1666. var settings GroupUserSettings
  1667. err = json.Unmarshal(userSettings, &settings)
  1668. if err == nil {
  1669. group.UserSettings = settings
  1670. }
  1671. return group, nil
  1672. }
  1673. func getUserFromDbRow(row sqlScanner) (User, error) {
  1674. var user User
  1675. var password sql.NullString
  1676. var permissions, publicKey, filters, fsConfig []byte
  1677. var additionalInfo, description, email, role sql.NullString
  1678. err := row.Scan(&user.ID, &user.Username, &password, &publicKey, &user.HomeDir, &user.UID, &user.GID, &user.MaxSessions,
  1679. &user.QuotaSize, &user.QuotaFiles, &permissions, &user.UsedQuotaSize, &user.UsedQuotaFiles, &user.LastQuotaUpdate,
  1680. &user.UploadBandwidth, &user.DownloadBandwidth, &user.ExpirationDate, &user.LastLogin, &user.Status, &filters, &fsConfig,
  1681. &additionalInfo, &description, &email, &user.CreatedAt, &user.UpdatedAt, &user.UploadDataTransfer, &user.DownloadDataTransfer,
  1682. &user.TotalDataTransfer, &user.UsedUploadDataTransfer, &user.UsedDownloadDataTransfer, &user.DeletedAt, &user.FirstDownload,
  1683. &user.FirstUpload, &role, &user.LastPasswordChange)
  1684. if err != nil {
  1685. if errors.Is(err, sql.ErrNoRows) {
  1686. return user, util.NewRecordNotFoundError(err.Error())
  1687. }
  1688. return user, err
  1689. }
  1690. if password.Valid {
  1691. user.Password = password.String
  1692. }
  1693. perms := make(map[string][]string)
  1694. err = json.Unmarshal(permissions, &perms)
  1695. if err != nil {
  1696. providerLog(logger.LevelError, "unable to deserialize permissions for user %#v: %v", user.Username, err)
  1697. return user, fmt.Errorf("unable to deserialize permissions for user %#v: %v", user.Username, err)
  1698. }
  1699. user.Permissions = perms
  1700. // we can have a empty string or an invalid json in null string
  1701. // so we do a relaxed test if the field is optional, for example we
  1702. // populate public keys only if unmarshal does not return an error
  1703. var pKeys []string
  1704. err = json.Unmarshal(publicKey, &pKeys)
  1705. if err == nil {
  1706. user.PublicKeys = pKeys
  1707. }
  1708. var userFilters UserFilters
  1709. err = json.Unmarshal(filters, &userFilters)
  1710. if err == nil {
  1711. user.Filters = userFilters
  1712. }
  1713. var fs vfs.Filesystem
  1714. err = json.Unmarshal(fsConfig, &fs)
  1715. if err == nil {
  1716. user.FsConfig = fs
  1717. }
  1718. if additionalInfo.Valid {
  1719. user.AdditionalInfo = additionalInfo.String
  1720. }
  1721. if description.Valid {
  1722. user.Description = description.String
  1723. }
  1724. if email.Valid {
  1725. user.Email = email.String
  1726. }
  1727. if role.Valid {
  1728. user.Role = role.String
  1729. }
  1730. user.SetEmptySecretsIfNil()
  1731. return user, nil
  1732. }
  1733. func sqlCommonGetFolder(ctx context.Context, name string, dbHandle sqlQuerier) (vfs.BaseVirtualFolder, error) {
  1734. var folder vfs.BaseVirtualFolder
  1735. q := getFolderByNameQuery()
  1736. row := dbHandle.QueryRowContext(ctx, q, name)
  1737. var mappedPath, description sql.NullString
  1738. var fsConfig []byte
  1739. err := row.Scan(&folder.ID, &mappedPath, &folder.UsedQuotaSize, &folder.UsedQuotaFiles, &folder.LastQuotaUpdate,
  1740. &folder.Name, &description, &fsConfig)
  1741. if err != nil {
  1742. if errors.Is(err, sql.ErrNoRows) {
  1743. return folder, util.NewRecordNotFoundError(err.Error())
  1744. }
  1745. return folder, err
  1746. }
  1747. if mappedPath.Valid {
  1748. folder.MappedPath = mappedPath.String
  1749. }
  1750. if description.Valid {
  1751. folder.Description = description.String
  1752. }
  1753. var fs vfs.Filesystem
  1754. err = json.Unmarshal(fsConfig, &fs)
  1755. if err == nil {
  1756. folder.FsConfig = fs
  1757. }
  1758. return folder, err
  1759. }
  1760. func sqlCommonGetFolderByName(ctx context.Context, name string, dbHandle sqlQuerier) (vfs.BaseVirtualFolder, error) {
  1761. folder, err := sqlCommonGetFolder(ctx, name, dbHandle)
  1762. if err != nil {
  1763. return folder, err
  1764. }
  1765. folders, err := getVirtualFoldersWithUsers([]vfs.BaseVirtualFolder{folder}, dbHandle)
  1766. if err != nil {
  1767. return folder, err
  1768. }
  1769. if len(folders) != 1 {
  1770. return folder, fmt.Errorf("unable to associate users with folder %#v", name)
  1771. }
  1772. folders, err = getVirtualFoldersWithGroups([]vfs.BaseVirtualFolder{folders[0]}, dbHandle)
  1773. if err != nil {
  1774. return folder, err
  1775. }
  1776. if len(folders) != 1 {
  1777. return folder, fmt.Errorf("unable to associate groups with folder %#v", name)
  1778. }
  1779. return folders[0], nil
  1780. }
  1781. func sqlCommonAddOrUpdateFolder(ctx context.Context, baseFolder *vfs.BaseVirtualFolder, usedQuotaSize int64,
  1782. usedQuotaFiles int, lastQuotaUpdate int64, dbHandle sqlQuerier,
  1783. ) error {
  1784. fsConfig, err := json.Marshal(baseFolder.FsConfig)
  1785. if err != nil {
  1786. return err
  1787. }
  1788. q := getUpsertFolderQuery()
  1789. _, err = dbHandle.ExecContext(ctx, q, baseFolder.MappedPath, usedQuotaSize, usedQuotaFiles,
  1790. lastQuotaUpdate, baseFolder.Name, baseFolder.Description, string(fsConfig))
  1791. return err
  1792. }
  1793. func sqlCommonAddFolder(folder *vfs.BaseVirtualFolder, dbHandle sqlQuerier) error {
  1794. err := ValidateFolder(folder)
  1795. if err != nil {
  1796. return err
  1797. }
  1798. fsConfig, err := json.Marshal(folder.FsConfig)
  1799. if err != nil {
  1800. return err
  1801. }
  1802. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  1803. defer cancel()
  1804. q := getAddFolderQuery()
  1805. _, err = dbHandle.ExecContext(ctx, q, folder.MappedPath, folder.UsedQuotaSize, folder.UsedQuotaFiles,
  1806. folder.LastQuotaUpdate, folder.Name, folder.Description, string(fsConfig))
  1807. return err
  1808. }
  1809. func sqlCommonUpdateFolder(folder *vfs.BaseVirtualFolder, dbHandle sqlQuerier) error {
  1810. err := ValidateFolder(folder)
  1811. if err != nil {
  1812. return err
  1813. }
  1814. fsConfig, err := json.Marshal(folder.FsConfig)
  1815. if err != nil {
  1816. return err
  1817. }
  1818. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  1819. defer cancel()
  1820. q := getUpdateFolderQuery()
  1821. _, err = dbHandle.ExecContext(ctx, q, folder.MappedPath, folder.Description, string(fsConfig), folder.Name)
  1822. return err
  1823. }
  1824. func sqlCommonDeleteFolder(folder vfs.BaseVirtualFolder, dbHandle sqlQuerier) error {
  1825. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  1826. defer cancel()
  1827. q := getDeleteFolderQuery()
  1828. res, err := dbHandle.ExecContext(ctx, q, folder.ID)
  1829. if err != nil {
  1830. return err
  1831. }
  1832. return sqlCommonRequireRowAffected(res)
  1833. }
  1834. func sqlCommonDumpFolders(dbHandle sqlQuerier) ([]vfs.BaseVirtualFolder, error) {
  1835. folders := make([]vfs.BaseVirtualFolder, 0, 50)
  1836. ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout)
  1837. defer cancel()
  1838. q := getDumpFoldersQuery()
  1839. rows, err := dbHandle.QueryContext(ctx, q)
  1840. if err != nil {
  1841. return folders, err
  1842. }
  1843. defer rows.Close()
  1844. for rows.Next() {
  1845. var folder vfs.BaseVirtualFolder
  1846. var mappedPath, description sql.NullString
  1847. var fsConfig []byte
  1848. err = rows.Scan(&folder.ID, &mappedPath, &folder.UsedQuotaSize, &folder.UsedQuotaFiles,
  1849. &folder.LastQuotaUpdate, &folder.Name, &description, &fsConfig)
  1850. if err != nil {
  1851. return folders, err
  1852. }
  1853. if mappedPath.Valid {
  1854. folder.MappedPath = mappedPath.String
  1855. }
  1856. if description.Valid {
  1857. folder.Description = description.String
  1858. }
  1859. var fs vfs.Filesystem
  1860. err = json.Unmarshal(fsConfig, &fs)
  1861. if err == nil {
  1862. folder.FsConfig = fs
  1863. }
  1864. folders = append(folders, folder)
  1865. }
  1866. return folders, rows.Err()
  1867. }
  1868. func sqlCommonGetFolders(limit, offset int, order string, minimal bool, dbHandle sqlQuerier) ([]vfs.BaseVirtualFolder, error) {
  1869. folders := make([]vfs.BaseVirtualFolder, 0, limit)
  1870. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  1871. defer cancel()
  1872. q := getFoldersQuery(order, minimal)
  1873. rows, err := dbHandle.QueryContext(ctx, q, limit, offset)
  1874. if err != nil {
  1875. return folders, err
  1876. }
  1877. defer rows.Close()
  1878. for rows.Next() {
  1879. var folder vfs.BaseVirtualFolder
  1880. if minimal {
  1881. err = rows.Scan(&folder.ID, &folder.Name)
  1882. if err != nil {
  1883. return folders, err
  1884. }
  1885. } else {
  1886. var mappedPath, description sql.NullString
  1887. var fsConfig []byte
  1888. err = rows.Scan(&folder.ID, &mappedPath, &folder.UsedQuotaSize, &folder.UsedQuotaFiles,
  1889. &folder.LastQuotaUpdate, &folder.Name, &description, &fsConfig)
  1890. if err != nil {
  1891. return folders, err
  1892. }
  1893. if mappedPath.Valid {
  1894. folder.MappedPath = mappedPath.String
  1895. }
  1896. if description.Valid {
  1897. folder.Description = description.String
  1898. }
  1899. var fs vfs.Filesystem
  1900. err = json.Unmarshal(fsConfig, &fs)
  1901. if err == nil {
  1902. folder.FsConfig = fs
  1903. }
  1904. }
  1905. folder.PrepareForRendering()
  1906. folders = append(folders, folder)
  1907. }
  1908. err = rows.Err()
  1909. if err != nil {
  1910. return folders, err
  1911. }
  1912. if minimal {
  1913. return folders, nil
  1914. }
  1915. folders, err = getVirtualFoldersWithUsers(folders, dbHandle)
  1916. if err != nil {
  1917. return folders, err
  1918. }
  1919. return getVirtualFoldersWithGroups(folders, dbHandle)
  1920. }
  1921. func sqlCommonClearUserFolderMapping(ctx context.Context, user *User, dbHandle sqlQuerier) error {
  1922. q := getClearUserFolderMappingQuery()
  1923. _, err := dbHandle.ExecContext(ctx, q, user.Username)
  1924. return err
  1925. }
  1926. func sqlCommonClearGroupFolderMapping(ctx context.Context, group *Group, dbHandle sqlQuerier) error {
  1927. q := getClearGroupFolderMappingQuery()
  1928. _, err := dbHandle.ExecContext(ctx, q, group.Name)
  1929. return err
  1930. }
  1931. func sqlCommonClearUserGroupMapping(ctx context.Context, user *User, dbHandle sqlQuerier) error {
  1932. q := getClearUserGroupMappingQuery()
  1933. _, err := dbHandle.ExecContext(ctx, q, user.Username)
  1934. return err
  1935. }
  1936. func sqlCommonAddUserFolderMapping(ctx context.Context, user *User, folder *vfs.VirtualFolder, dbHandle sqlQuerier) error {
  1937. q := getAddUserFolderMappingQuery()
  1938. _, err := dbHandle.ExecContext(ctx, q, folder.VirtualPath, folder.QuotaSize, folder.QuotaFiles, folder.Name, user.Username)
  1939. return err
  1940. }
  1941. func sqlCommonClearAdminGroupMapping(ctx context.Context, admin *Admin, dbHandle sqlQuerier) error {
  1942. q := getClearAdminGroupMappingQuery()
  1943. _, err := dbHandle.ExecContext(ctx, q, admin.Username)
  1944. return err
  1945. }
  1946. func sqlCommonAddGroupFolderMapping(ctx context.Context, group *Group, folder *vfs.VirtualFolder, dbHandle sqlQuerier) error {
  1947. q := getAddGroupFolderMappingQuery()
  1948. _, err := dbHandle.ExecContext(ctx, q, folder.VirtualPath, folder.QuotaSize, folder.QuotaFiles, folder.Name, group.Name)
  1949. return err
  1950. }
  1951. func sqlCommonAddUserGroupMapping(ctx context.Context, username, groupName string, groupType int, dbHandle sqlQuerier) error {
  1952. q := getAddUserGroupMappingQuery()
  1953. _, err := dbHandle.ExecContext(ctx, q, username, groupName, groupType)
  1954. return err
  1955. }
  1956. func sqlCommonAddAdminGroupMapping(ctx context.Context, username, groupName string, mappingOptions AdminGroupMappingOptions,
  1957. dbHandle sqlQuerier,
  1958. ) error {
  1959. options, err := json.Marshal(mappingOptions)
  1960. if err != nil {
  1961. return err
  1962. }
  1963. q := getAddAdminGroupMappingQuery()
  1964. _, err = dbHandle.ExecContext(ctx, q, username, groupName, string(options))
  1965. return err
  1966. }
  1967. func generateGroupVirtualFoldersMapping(ctx context.Context, group *Group, dbHandle sqlQuerier) error {
  1968. err := sqlCommonClearGroupFolderMapping(ctx, group, dbHandle)
  1969. if err != nil {
  1970. return err
  1971. }
  1972. for idx := range group.VirtualFolders {
  1973. vfolder := &group.VirtualFolders[idx]
  1974. err = sqlCommonAddOrUpdateFolder(ctx, &vfolder.BaseVirtualFolder, 0, 0, 0, dbHandle)
  1975. if err != nil {
  1976. return err
  1977. }
  1978. err = sqlCommonAddGroupFolderMapping(ctx, group, vfolder, dbHandle)
  1979. if err != nil {
  1980. return err
  1981. }
  1982. }
  1983. return err
  1984. }
  1985. func generateUserVirtualFoldersMapping(ctx context.Context, user *User, dbHandle sqlQuerier) error {
  1986. err := sqlCommonClearUserFolderMapping(ctx, user, dbHandle)
  1987. if err != nil {
  1988. return err
  1989. }
  1990. for idx := range user.VirtualFolders {
  1991. vfolder := &user.VirtualFolders[idx]
  1992. err := sqlCommonAddOrUpdateFolder(ctx, &vfolder.BaseVirtualFolder, 0, 0, 0, dbHandle)
  1993. if err != nil {
  1994. return err
  1995. }
  1996. err = sqlCommonAddUserFolderMapping(ctx, user, vfolder, dbHandle)
  1997. if err != nil {
  1998. return err
  1999. }
  2000. }
  2001. return err
  2002. }
  2003. func generateUserGroupMapping(ctx context.Context, user *User, dbHandle sqlQuerier) error {
  2004. err := sqlCommonClearUserGroupMapping(ctx, user, dbHandle)
  2005. if err != nil {
  2006. return err
  2007. }
  2008. for _, group := range user.Groups {
  2009. err = sqlCommonAddUserGroupMapping(ctx, user.Username, group.Name, group.Type, dbHandle)
  2010. if err != nil {
  2011. return err
  2012. }
  2013. }
  2014. return err
  2015. }
  2016. func generateAdminGroupMapping(ctx context.Context, admin *Admin, dbHandle sqlQuerier) error {
  2017. err := sqlCommonClearAdminGroupMapping(ctx, admin, dbHandle)
  2018. if err != nil {
  2019. return err
  2020. }
  2021. for _, group := range admin.Groups {
  2022. err = sqlCommonAddAdminGroupMapping(ctx, admin.Username, group.Name, group.Options, dbHandle)
  2023. if err != nil {
  2024. return err
  2025. }
  2026. }
  2027. return err
  2028. }
  2029. func getDefenderHostsWithScores(ctx context.Context, hosts []DefenderEntry, from int64, idForScores []int64,
  2030. dbHandle sqlQuerier) (
  2031. []DefenderEntry,
  2032. error,
  2033. ) {
  2034. if len(idForScores) == 0 {
  2035. return hosts, nil
  2036. }
  2037. hostsWithScores := make(map[int64]int)
  2038. q := getDefenderEventsQuery(idForScores)
  2039. rows, err := dbHandle.QueryContext(ctx, q, from)
  2040. if err != nil {
  2041. providerLog(logger.LevelError, "unable to get score for hosts with id %+v: %v", idForScores, err)
  2042. return nil, err
  2043. }
  2044. defer rows.Close()
  2045. for rows.Next() {
  2046. var hostID int64
  2047. var score int
  2048. err = rows.Scan(&hostID, &score)
  2049. if err != nil {
  2050. providerLog(logger.LevelError, "error scanning host score row: %v", err)
  2051. return hosts, err
  2052. }
  2053. if score > 0 {
  2054. hostsWithScores[hostID] = score
  2055. }
  2056. }
  2057. err = rows.Err()
  2058. if err != nil {
  2059. return hosts, err
  2060. }
  2061. result := make([]DefenderEntry, 0, len(hosts))
  2062. for idx := range hosts {
  2063. hosts[idx].Score = hostsWithScores[hosts[idx].ID]
  2064. if hosts[idx].Score > 0 || !hosts[idx].BanTime.IsZero() {
  2065. result = append(result, hosts[idx])
  2066. }
  2067. }
  2068. return result, nil
  2069. }
  2070. func getAdminWithGroups(ctx context.Context, admin Admin, dbHandle sqlQuerier) (Admin, error) {
  2071. admins, err := getAdminsWithGroups(ctx, []Admin{admin}, dbHandle)
  2072. if err != nil {
  2073. return admin, err
  2074. }
  2075. if len(admins) == 0 {
  2076. return admin, errSQLGroupsAssociation
  2077. }
  2078. return admins[0], err
  2079. }
  2080. func getAdminsWithGroups(ctx context.Context, admins []Admin, dbHandle sqlQuerier) ([]Admin, error) {
  2081. if len(admins) == 0 {
  2082. return admins, nil
  2083. }
  2084. adminsGroups := make(map[int64][]AdminGroupMapping)
  2085. q := getRelatedGroupsForAdminsQuery(admins)
  2086. rows, err := dbHandle.QueryContext(ctx, q)
  2087. if err != nil {
  2088. return nil, err
  2089. }
  2090. defer rows.Close()
  2091. for rows.Next() {
  2092. var group AdminGroupMapping
  2093. var adminID int64
  2094. var options []byte
  2095. err = rows.Scan(&group.Name, &options, &adminID)
  2096. if err != nil {
  2097. return admins, err
  2098. }
  2099. err = json.Unmarshal(options, &group.Options)
  2100. if err != nil {
  2101. return admins, err
  2102. }
  2103. adminsGroups[adminID] = append(adminsGroups[adminID], group)
  2104. }
  2105. err = rows.Err()
  2106. if err != nil {
  2107. return admins, err
  2108. }
  2109. if len(adminsGroups) == 0 {
  2110. return admins, err
  2111. }
  2112. for idx := range admins {
  2113. ref := &admins[idx]
  2114. ref.Groups = adminsGroups[ref.ID]
  2115. }
  2116. return admins, err
  2117. }
  2118. func getUserWithVirtualFolders(ctx context.Context, user User, dbHandle sqlQuerier) (User, error) {
  2119. users, err := getUsersWithVirtualFolders(ctx, []User{user}, dbHandle)
  2120. if err != nil {
  2121. return user, err
  2122. }
  2123. if len(users) == 0 {
  2124. return user, errSQLFoldersAssociation
  2125. }
  2126. return users[0], err
  2127. }
  2128. func getUsersWithVirtualFolders(ctx context.Context, users []User, dbHandle sqlQuerier) ([]User, error) {
  2129. if len(users) == 0 {
  2130. return users, nil
  2131. }
  2132. usersVirtualFolders := make(map[int64][]vfs.VirtualFolder)
  2133. q := getRelatedFoldersForUsersQuery(users)
  2134. rows, err := dbHandle.QueryContext(ctx, q)
  2135. if err != nil {
  2136. return nil, err
  2137. }
  2138. defer rows.Close()
  2139. for rows.Next() {
  2140. var folder vfs.VirtualFolder
  2141. var userID int64
  2142. var mappedPath, description sql.NullString
  2143. var fsConfig []byte
  2144. err = rows.Scan(&folder.ID, &folder.Name, &mappedPath, &folder.UsedQuotaSize, &folder.UsedQuotaFiles,
  2145. &folder.LastQuotaUpdate, &folder.VirtualPath, &folder.QuotaSize, &folder.QuotaFiles, &userID, &fsConfig,
  2146. &description)
  2147. if err != nil {
  2148. return users, err
  2149. }
  2150. if mappedPath.Valid {
  2151. folder.MappedPath = mappedPath.String
  2152. }
  2153. if description.Valid {
  2154. folder.Description = description.String
  2155. }
  2156. var fs vfs.Filesystem
  2157. err = json.Unmarshal(fsConfig, &fs)
  2158. if err == nil {
  2159. folder.FsConfig = fs
  2160. }
  2161. usersVirtualFolders[userID] = append(usersVirtualFolders[userID], folder)
  2162. }
  2163. err = rows.Err()
  2164. if err != nil {
  2165. return users, err
  2166. }
  2167. if len(usersVirtualFolders) == 0 {
  2168. return users, err
  2169. }
  2170. for idx := range users {
  2171. ref := &users[idx]
  2172. ref.VirtualFolders = usersVirtualFolders[ref.ID]
  2173. }
  2174. return users, err
  2175. }
  2176. func getUserWithGroups(ctx context.Context, user User, dbHandle sqlQuerier) (User, error) {
  2177. users, err := getUsersWithGroups(ctx, []User{user}, dbHandle)
  2178. if err != nil {
  2179. return user, err
  2180. }
  2181. if len(users) == 0 {
  2182. return user, errSQLGroupsAssociation
  2183. }
  2184. return users[0], err
  2185. }
  2186. func getUsersWithGroups(ctx context.Context, users []User, dbHandle sqlQuerier) ([]User, error) {
  2187. if len(users) == 0 {
  2188. return users, nil
  2189. }
  2190. usersGroups := make(map[int64][]sdk.GroupMapping)
  2191. q := getRelatedGroupsForUsersQuery(users)
  2192. rows, err := dbHandle.QueryContext(ctx, q)
  2193. if err != nil {
  2194. return nil, err
  2195. }
  2196. defer rows.Close()
  2197. for rows.Next() {
  2198. var group sdk.GroupMapping
  2199. var userID int64
  2200. err = rows.Scan(&group.Name, &group.Type, &userID)
  2201. if err != nil {
  2202. return users, err
  2203. }
  2204. usersGroups[userID] = append(usersGroups[userID], group)
  2205. }
  2206. err = rows.Err()
  2207. if err != nil {
  2208. return users, err
  2209. }
  2210. if len(usersGroups) == 0 {
  2211. return users, err
  2212. }
  2213. for idx := range users {
  2214. ref := &users[idx]
  2215. ref.Groups = usersGroups[ref.ID]
  2216. }
  2217. return users, err
  2218. }
  2219. func getGroupWithUsers(ctx context.Context, group Group, dbHandle sqlQuerier) (Group, error) {
  2220. groups, err := getGroupsWithUsers(ctx, []Group{group}, dbHandle)
  2221. if err != nil {
  2222. return group, err
  2223. }
  2224. if len(groups) == 0 {
  2225. return group, errSQLUsersAssociation
  2226. }
  2227. return groups[0], err
  2228. }
  2229. func getRoleWithUsers(ctx context.Context, role Role, dbHandle sqlQuerier) (Role, error) {
  2230. roles, err := getRolesWithUsers(ctx, []Role{role}, dbHandle)
  2231. if err != nil {
  2232. return role, err
  2233. }
  2234. if len(roles) == 0 {
  2235. return role, errors.New("unable to associate users with role")
  2236. }
  2237. return roles[0], err
  2238. }
  2239. func getRoleWithAdmins(ctx context.Context, role Role, dbHandle sqlQuerier) (Role, error) {
  2240. roles, err := getRolesWithAdmins(ctx, []Role{role}, dbHandle)
  2241. if err != nil {
  2242. return role, err
  2243. }
  2244. if len(roles) == 0 {
  2245. return role, errors.New("unable to associate admins with role")
  2246. }
  2247. return roles[0], err
  2248. }
  2249. func getGroupWithAdmins(ctx context.Context, group Group, dbHandle sqlQuerier) (Group, error) {
  2250. groups, err := getGroupsWithAdmins(ctx, []Group{group}, dbHandle)
  2251. if err != nil {
  2252. return group, err
  2253. }
  2254. if len(groups) == 0 {
  2255. return group, errSQLUsersAssociation
  2256. }
  2257. return groups[0], err
  2258. }
  2259. func getGroupWithVirtualFolders(ctx context.Context, group Group, dbHandle sqlQuerier) (Group, error) {
  2260. groups, err := getGroupsWithVirtualFolders(ctx, []Group{group}, dbHandle)
  2261. if err != nil {
  2262. return group, err
  2263. }
  2264. if len(groups) == 0 {
  2265. return group, errSQLFoldersAssociation
  2266. }
  2267. return groups[0], err
  2268. }
  2269. func getGroupsWithVirtualFolders(ctx context.Context, groups []Group, dbHandle sqlQuerier) ([]Group, error) {
  2270. if len(groups) == 0 {
  2271. return groups, nil
  2272. }
  2273. q := getRelatedFoldersForGroupsQuery(groups)
  2274. rows, err := dbHandle.QueryContext(ctx, q)
  2275. if err != nil {
  2276. return nil, err
  2277. }
  2278. defer rows.Close()
  2279. groupsVirtualFolders := make(map[int64][]vfs.VirtualFolder)
  2280. for rows.Next() {
  2281. var groupID int64
  2282. var folder vfs.VirtualFolder
  2283. var mappedPath, description sql.NullString
  2284. var fsConfig []byte
  2285. err = rows.Scan(&folder.ID, &folder.Name, &mappedPath, &folder.UsedQuotaSize, &folder.UsedQuotaFiles,
  2286. &folder.LastQuotaUpdate, &folder.VirtualPath, &folder.QuotaSize, &folder.QuotaFiles, &groupID, &fsConfig,
  2287. &description)
  2288. if err != nil {
  2289. return groups, err
  2290. }
  2291. if mappedPath.Valid {
  2292. folder.MappedPath = mappedPath.String
  2293. }
  2294. if description.Valid {
  2295. folder.Description = description.String
  2296. }
  2297. var fs vfs.Filesystem
  2298. err = json.Unmarshal(fsConfig, &fs)
  2299. if err == nil {
  2300. folder.FsConfig = fs
  2301. }
  2302. groupsVirtualFolders[groupID] = append(groupsVirtualFolders[groupID], folder)
  2303. }
  2304. err = rows.Err()
  2305. if err != nil {
  2306. return groups, err
  2307. }
  2308. if len(groupsVirtualFolders) == 0 {
  2309. return groups, err
  2310. }
  2311. for idx := range groups {
  2312. ref := &groups[idx]
  2313. ref.VirtualFolders = groupsVirtualFolders[ref.ID]
  2314. }
  2315. return groups, err
  2316. }
  2317. func getGroupsWithUsers(ctx context.Context, groups []Group, dbHandle sqlQuerier) ([]Group, error) {
  2318. if len(groups) == 0 {
  2319. return groups, nil
  2320. }
  2321. q := getRelatedUsersForGroupsQuery(groups)
  2322. rows, err := dbHandle.QueryContext(ctx, q)
  2323. if err != nil {
  2324. return nil, err
  2325. }
  2326. defer rows.Close()
  2327. groupsUsers := make(map[int64][]string)
  2328. for rows.Next() {
  2329. var username string
  2330. var groupID int64
  2331. err = rows.Scan(&groupID, &username)
  2332. if err != nil {
  2333. return groups, err
  2334. }
  2335. groupsUsers[groupID] = append(groupsUsers[groupID], username)
  2336. }
  2337. err = rows.Err()
  2338. if err != nil {
  2339. return groups, err
  2340. }
  2341. if len(groupsUsers) == 0 {
  2342. return groups, err
  2343. }
  2344. for idx := range groups {
  2345. ref := &groups[idx]
  2346. ref.Users = groupsUsers[ref.ID]
  2347. }
  2348. return groups, err
  2349. }
  2350. func getRolesWithUsers(ctx context.Context, roles []Role, dbHandle sqlQuerier) ([]Role, error) {
  2351. if len(roles) == 0 {
  2352. return roles, nil
  2353. }
  2354. rows, err := dbHandle.QueryContext(ctx, getUsersWithRolesQuery(roles))
  2355. if err != nil {
  2356. return nil, err
  2357. }
  2358. defer rows.Close()
  2359. rolesUsers := make(map[int64][]string)
  2360. for rows.Next() {
  2361. var roleID int64
  2362. var username string
  2363. err = rows.Scan(&roleID, &username)
  2364. if err != nil {
  2365. return roles, err
  2366. }
  2367. rolesUsers[roleID] = append(rolesUsers[roleID], username)
  2368. }
  2369. err = rows.Err()
  2370. if err != nil {
  2371. return roles, err
  2372. }
  2373. if len(rolesUsers) > 0 {
  2374. for idx := range roles {
  2375. ref := &roles[idx]
  2376. ref.Users = rolesUsers[ref.ID]
  2377. }
  2378. }
  2379. return roles, nil
  2380. }
  2381. func getRolesWithAdmins(ctx context.Context, roles []Role, dbHandle sqlQuerier) ([]Role, error) {
  2382. if len(roles) == 0 {
  2383. return roles, nil
  2384. }
  2385. rows, err := dbHandle.QueryContext(ctx, getAdminsWithRolesQuery(roles))
  2386. if err != nil {
  2387. return nil, err
  2388. }
  2389. defer rows.Close()
  2390. rolesAdmins := make(map[int64][]string)
  2391. for rows.Next() {
  2392. var roleID int64
  2393. var username string
  2394. err = rows.Scan(&roleID, &username)
  2395. if err != nil {
  2396. return roles, err
  2397. }
  2398. rolesAdmins[roleID] = append(rolesAdmins[roleID], username)
  2399. }
  2400. if err = rows.Err(); err != nil {
  2401. return roles, err
  2402. }
  2403. if len(rolesAdmins) > 0 {
  2404. for idx := range roles {
  2405. ref := &roles[idx]
  2406. ref.Admins = rolesAdmins[ref.ID]
  2407. }
  2408. }
  2409. return roles, nil
  2410. }
  2411. func getGroupsWithAdmins(ctx context.Context, groups []Group, dbHandle sqlQuerier) ([]Group, error) {
  2412. if len(groups) == 0 {
  2413. return groups, nil
  2414. }
  2415. q := getRelatedAdminsForGroupsQuery(groups)
  2416. rows, err := dbHandle.QueryContext(ctx, q)
  2417. if err != nil {
  2418. return nil, err
  2419. }
  2420. defer rows.Close()
  2421. groupsAdmins := make(map[int64][]string)
  2422. for rows.Next() {
  2423. var groupID int64
  2424. var username string
  2425. err = rows.Scan(&groupID, &username)
  2426. if err != nil {
  2427. return groups, err
  2428. }
  2429. groupsAdmins[groupID] = append(groupsAdmins[groupID], username)
  2430. }
  2431. err = rows.Err()
  2432. if err != nil {
  2433. return groups, err
  2434. }
  2435. if len(groupsAdmins) > 0 {
  2436. for idx := range groups {
  2437. ref := &groups[idx]
  2438. ref.Admins = groupsAdmins[ref.ID]
  2439. }
  2440. }
  2441. return groups, nil
  2442. }
  2443. func getVirtualFoldersWithGroups(folders []vfs.BaseVirtualFolder, dbHandle sqlQuerier) ([]vfs.BaseVirtualFolder, error) {
  2444. if len(folders) == 0 {
  2445. return folders, nil
  2446. }
  2447. vFoldersGroups := make(map[int64][]string)
  2448. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  2449. defer cancel()
  2450. q := getRelatedGroupsForFoldersQuery(folders)
  2451. rows, err := dbHandle.QueryContext(ctx, q)
  2452. if err != nil {
  2453. return nil, err
  2454. }
  2455. defer rows.Close()
  2456. for rows.Next() {
  2457. var name string
  2458. var folderID int64
  2459. err = rows.Scan(&folderID, &name)
  2460. if err != nil {
  2461. return folders, err
  2462. }
  2463. vFoldersGroups[folderID] = append(vFoldersGroups[folderID], name)
  2464. }
  2465. err = rows.Err()
  2466. if err != nil {
  2467. return folders, err
  2468. }
  2469. if len(vFoldersGroups) == 0 {
  2470. return folders, err
  2471. }
  2472. for idx := range folders {
  2473. ref := &folders[idx]
  2474. ref.Groups = vFoldersGroups[ref.ID]
  2475. }
  2476. return folders, err
  2477. }
  2478. func getVirtualFoldersWithUsers(folders []vfs.BaseVirtualFolder, dbHandle sqlQuerier) ([]vfs.BaseVirtualFolder, error) {
  2479. if len(folders) == 0 {
  2480. return folders, nil
  2481. }
  2482. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  2483. defer cancel()
  2484. q := getRelatedUsersForFoldersQuery(folders)
  2485. rows, err := dbHandle.QueryContext(ctx, q)
  2486. if err != nil {
  2487. return nil, err
  2488. }
  2489. defer rows.Close()
  2490. vFoldersUsers := make(map[int64][]string)
  2491. for rows.Next() {
  2492. var username string
  2493. var folderID int64
  2494. err = rows.Scan(&folderID, &username)
  2495. if err != nil {
  2496. return folders, err
  2497. }
  2498. vFoldersUsers[folderID] = append(vFoldersUsers[folderID], username)
  2499. }
  2500. err = rows.Err()
  2501. if err != nil {
  2502. return folders, err
  2503. }
  2504. if len(vFoldersUsers) == 0 {
  2505. return folders, err
  2506. }
  2507. for idx := range folders {
  2508. ref := &folders[idx]
  2509. ref.Users = vFoldersUsers[ref.ID]
  2510. }
  2511. return folders, err
  2512. }
  2513. func sqlCommonUpdateFolderQuota(name string, filesAdd int, sizeAdd int64, reset bool, dbHandle *sql.DB) error {
  2514. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  2515. defer cancel()
  2516. q := getUpdateFolderQuotaQuery(reset)
  2517. _, err := dbHandle.ExecContext(ctx, q, sizeAdd, filesAdd, util.GetTimeAsMsSinceEpoch(time.Now()), name)
  2518. if err == nil {
  2519. providerLog(logger.LevelDebug, "quota updated for folder %#v, files increment: %v size increment: %v is reset? %v",
  2520. name, filesAdd, sizeAdd, reset)
  2521. } else {
  2522. providerLog(logger.LevelWarn, "error updating quota for folder %#v: %v", name, err)
  2523. }
  2524. return err
  2525. }
  2526. func sqlCommonGetFolderUsedQuota(mappedPath string, dbHandle *sql.DB) (int, int64, error) {
  2527. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  2528. defer cancel()
  2529. q := getQuotaFolderQuery()
  2530. var usedFiles int
  2531. var usedSize int64
  2532. err := dbHandle.QueryRowContext(ctx, q, mappedPath).Scan(&usedSize, &usedFiles)
  2533. if err != nil {
  2534. providerLog(logger.LevelError, "error getting quota for folder: %v, error: %v", mappedPath, err)
  2535. return 0, 0, err
  2536. }
  2537. return usedFiles, usedSize, err
  2538. }
  2539. func getAPIKeyWithRelatedFields(ctx context.Context, apiKey APIKey, dbHandle sqlQuerier) (APIKey, error) {
  2540. var apiKeys []APIKey
  2541. var err error
  2542. scope := APIKeyScopeAdmin
  2543. if apiKey.userID > 0 {
  2544. scope = APIKeyScopeUser
  2545. }
  2546. apiKeys, err = getRelatedValuesForAPIKeys(ctx, []APIKey{apiKey}, dbHandle, scope)
  2547. if err != nil {
  2548. return apiKey, err
  2549. }
  2550. if len(apiKeys) > 0 {
  2551. apiKey = apiKeys[0]
  2552. }
  2553. return apiKey, nil
  2554. }
  2555. func getRelatedValuesForAPIKeys(ctx context.Context, apiKeys []APIKey, dbHandle sqlQuerier, scope APIKeyScope) ([]APIKey, error) {
  2556. if len(apiKeys) == 0 {
  2557. return apiKeys, nil
  2558. }
  2559. values := make(map[int64]string)
  2560. var q string
  2561. if scope == APIKeyScopeUser {
  2562. q = getRelatedUsersForAPIKeysQuery(apiKeys)
  2563. } else {
  2564. q = getRelatedAdminsForAPIKeysQuery(apiKeys)
  2565. }
  2566. rows, err := dbHandle.QueryContext(ctx, q)
  2567. if err != nil {
  2568. return nil, err
  2569. }
  2570. defer rows.Close()
  2571. for rows.Next() {
  2572. var valueID int64
  2573. var valueName string
  2574. err = rows.Scan(&valueID, &valueName)
  2575. if err != nil {
  2576. return apiKeys, err
  2577. }
  2578. values[valueID] = valueName
  2579. }
  2580. err = rows.Err()
  2581. if err != nil {
  2582. return apiKeys, err
  2583. }
  2584. if len(values) == 0 {
  2585. return apiKeys, nil
  2586. }
  2587. for idx := range apiKeys {
  2588. ref := &apiKeys[idx]
  2589. if scope == APIKeyScopeUser {
  2590. ref.User = values[ref.userID]
  2591. } else {
  2592. ref.Admin = values[ref.adminID]
  2593. }
  2594. }
  2595. return apiKeys, nil
  2596. }
  2597. func sqlCommonGetAPIKeyRelatedIDs(apiKey *APIKey) (sql.NullInt64, sql.NullInt64, error) {
  2598. var userID, adminID sql.NullInt64
  2599. if apiKey.User != "" {
  2600. u, err := provider.userExists(apiKey.User, "")
  2601. if err != nil {
  2602. return userID, adminID, util.NewValidationError(fmt.Sprintf("unable to validate user %v", apiKey.User))
  2603. }
  2604. userID.Valid = true
  2605. userID.Int64 = u.ID
  2606. }
  2607. if apiKey.Admin != "" {
  2608. a, err := provider.adminExists(apiKey.Admin)
  2609. if err != nil {
  2610. return userID, adminID, util.NewValidationError(fmt.Sprintf("unable to validate admin %v", apiKey.Admin))
  2611. }
  2612. adminID.Valid = true
  2613. adminID.Int64 = a.ID
  2614. }
  2615. return userID, adminID, nil
  2616. }
  2617. func sqlCommonAddSession(session Session, dbHandle *sql.DB) error {
  2618. if err := session.validate(); err != nil {
  2619. return err
  2620. }
  2621. data, err := json.Marshal(session.Data)
  2622. if err != nil {
  2623. return err
  2624. }
  2625. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  2626. defer cancel()
  2627. q := getAddSessionQuery()
  2628. _, err = dbHandle.ExecContext(ctx, q, session.Key, data, session.Type, session.Timestamp)
  2629. return err
  2630. }
  2631. func sqlCommonGetSession(key string, dbHandle sqlQuerier) (Session, error) {
  2632. var session Session
  2633. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  2634. defer cancel()
  2635. q := getSessionQuery()
  2636. var data []byte // type hint, some driver will use string instead of []byte if the type is any
  2637. err := dbHandle.QueryRowContext(ctx, q, key).Scan(&session.Key, &data, &session.Type, &session.Timestamp)
  2638. if err != nil {
  2639. return session, err
  2640. }
  2641. session.Data = data
  2642. return session, nil
  2643. }
  2644. func sqlCommonDeleteSession(key string, dbHandle *sql.DB) error {
  2645. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  2646. defer cancel()
  2647. q := getDeleteSessionQuery()
  2648. res, err := dbHandle.ExecContext(ctx, q, key)
  2649. if err != nil {
  2650. return err
  2651. }
  2652. return sqlCommonRequireRowAffected(res)
  2653. }
  2654. func sqlCommonCleanupSessions(sessionType SessionType, before int64, dbHandle *sql.DB) error {
  2655. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  2656. defer cancel()
  2657. q := getCleanupSessionsQuery()
  2658. _, err := dbHandle.ExecContext(ctx, q, sessionType, before)
  2659. return err
  2660. }
  2661. func getActionsWithRuleNames(ctx context.Context, actions []BaseEventAction, dbHandle sqlQuerier,
  2662. ) ([]BaseEventAction, error) {
  2663. if len(actions) == 0 {
  2664. return actions, nil
  2665. }
  2666. q := getRelatedRulesForActionsQuery(actions)
  2667. rows, err := dbHandle.QueryContext(ctx, q)
  2668. if err != nil {
  2669. return nil, err
  2670. }
  2671. defer rows.Close()
  2672. actionsRules := make(map[int64][]string)
  2673. for rows.Next() {
  2674. var name string
  2675. var actionID int64
  2676. if err = rows.Scan(&actionID, &name); err != nil {
  2677. return nil, err
  2678. }
  2679. actionsRules[actionID] = append(actionsRules[actionID], name)
  2680. }
  2681. err = rows.Err()
  2682. if err != nil {
  2683. return nil, err
  2684. }
  2685. if len(actionsRules) == 0 {
  2686. return actions, nil
  2687. }
  2688. for idx := range actions {
  2689. ref := &actions[idx]
  2690. ref.Rules = actionsRules[ref.ID]
  2691. }
  2692. return actions, nil
  2693. }
  2694. func getRulesWithActions(ctx context.Context, rules []EventRule, dbHandle sqlQuerier) ([]EventRule, error) {
  2695. if len(rules) == 0 {
  2696. return rules, nil
  2697. }
  2698. rulesActions := make(map[int64][]EventAction)
  2699. q := getRelatedActionsForRulesQuery(rules)
  2700. rows, err := dbHandle.QueryContext(ctx, q)
  2701. if err != nil {
  2702. return nil, err
  2703. }
  2704. defer rows.Close()
  2705. for rows.Next() {
  2706. var action EventAction
  2707. var ruleID int64
  2708. var description sql.NullString
  2709. var baseOptions, options []byte
  2710. err = rows.Scan(&action.ID, &action.Name, &description, &action.Type, &baseOptions, &options,
  2711. &action.Order, &ruleID)
  2712. if err != nil {
  2713. return rules, err
  2714. }
  2715. if len(baseOptions) > 0 {
  2716. err = json.Unmarshal(baseOptions, &action.BaseEventAction.Options)
  2717. if err != nil {
  2718. return rules, err
  2719. }
  2720. }
  2721. if len(options) > 0 {
  2722. err = json.Unmarshal(options, &action.Options)
  2723. if err != nil {
  2724. return rules, err
  2725. }
  2726. }
  2727. action.BaseEventAction.Options.SetEmptySecretsIfNil()
  2728. rulesActions[ruleID] = append(rulesActions[ruleID], action)
  2729. }
  2730. err = rows.Err()
  2731. if err != nil {
  2732. return rules, err
  2733. }
  2734. if len(rulesActions) == 0 {
  2735. return rules, nil
  2736. }
  2737. for idx := range rules {
  2738. ref := &rules[idx]
  2739. ref.Actions = rulesActions[ref.ID]
  2740. }
  2741. return rules, nil
  2742. }
  2743. func generateEventRuleActionsMapping(ctx context.Context, rule *EventRule, dbHandle sqlQuerier) error {
  2744. q := getClearRuleActionMappingQuery()
  2745. _, err := dbHandle.ExecContext(ctx, q, rule.Name)
  2746. if err != nil {
  2747. return err
  2748. }
  2749. for _, action := range rule.Actions {
  2750. options, err := json.Marshal(action.Options)
  2751. if err != nil {
  2752. return err
  2753. }
  2754. q = getAddRuleActionMappingQuery()
  2755. _, err = dbHandle.ExecContext(ctx, q, rule.Name, action.Name, action.Order, string(options))
  2756. if err != nil {
  2757. return err
  2758. }
  2759. }
  2760. return nil
  2761. }
  2762. func sqlCommonGetEventActionByName(name string, dbHandle sqlQuerier) (BaseEventAction, error) {
  2763. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  2764. defer cancel()
  2765. q := getEventActionByNameQuery()
  2766. row := dbHandle.QueryRowContext(ctx, q, name)
  2767. action, err := getEventActionFromDbRow(row)
  2768. if err != nil {
  2769. return action, err
  2770. }
  2771. actions, err := getActionsWithRuleNames(ctx, []BaseEventAction{action}, dbHandle)
  2772. if err != nil {
  2773. return action, err
  2774. }
  2775. if len(actions) != 1 {
  2776. return action, fmt.Errorf("unable to associate rules with action %q", name)
  2777. }
  2778. return actions[0], nil
  2779. }
  2780. func sqlCommonDumpEventActions(dbHandle sqlQuerier) ([]BaseEventAction, error) {
  2781. actions := make([]BaseEventAction, 0, 10)
  2782. ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout)
  2783. defer cancel()
  2784. q := getDumpEventActionsQuery()
  2785. rows, err := dbHandle.QueryContext(ctx, q)
  2786. if err != nil {
  2787. return actions, err
  2788. }
  2789. defer rows.Close()
  2790. for rows.Next() {
  2791. action, err := getEventActionFromDbRow(rows)
  2792. if err != nil {
  2793. return actions, err
  2794. }
  2795. actions = append(actions, action)
  2796. }
  2797. return actions, rows.Err()
  2798. }
  2799. func sqlCommonGetEventActions(limit int, offset int, order string, minimal bool,
  2800. dbHandle sqlQuerier,
  2801. ) ([]BaseEventAction, error) {
  2802. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  2803. defer cancel()
  2804. q := getEventsActionsQuery(order, minimal)
  2805. actions := make([]BaseEventAction, 0, limit)
  2806. rows, err := dbHandle.QueryContext(ctx, q, limit, offset)
  2807. if err != nil {
  2808. return actions, err
  2809. }
  2810. defer rows.Close()
  2811. for rows.Next() {
  2812. var action BaseEventAction
  2813. if minimal {
  2814. err = rows.Scan(&action.ID, &action.Name)
  2815. } else {
  2816. action, err = getEventActionFromDbRow(rows)
  2817. }
  2818. if err != nil {
  2819. return actions, err
  2820. }
  2821. actions = append(actions, action)
  2822. }
  2823. err = rows.Err()
  2824. if err != nil {
  2825. return nil, err
  2826. }
  2827. if minimal {
  2828. return actions, nil
  2829. }
  2830. actions, err = getActionsWithRuleNames(ctx, actions, dbHandle)
  2831. if err != nil {
  2832. return nil, err
  2833. }
  2834. for idx := range actions {
  2835. actions[idx].PrepareForRendering()
  2836. }
  2837. return actions, nil
  2838. }
  2839. func sqlCommonAddEventAction(action *BaseEventAction, dbHandle *sql.DB) error {
  2840. if err := action.validate(); err != nil {
  2841. return err
  2842. }
  2843. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  2844. defer cancel()
  2845. q := getAddEventActionQuery()
  2846. options, err := json.Marshal(action.Options)
  2847. if err != nil {
  2848. return err
  2849. }
  2850. _, err = dbHandle.ExecContext(ctx, q, action.Name, action.Description, action.Type, string(options))
  2851. return err
  2852. }
  2853. func sqlCommonUpdateEventAction(action *BaseEventAction, dbHandle *sql.DB) error {
  2854. if err := action.validate(); err != nil {
  2855. return err
  2856. }
  2857. options, err := json.Marshal(action.Options)
  2858. if err != nil {
  2859. return err
  2860. }
  2861. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  2862. defer cancel()
  2863. return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error {
  2864. q := getUpdateEventActionQuery()
  2865. _, err = tx.ExecContext(ctx, q, action.Description, action.Type, string(options), action.Name)
  2866. if err != nil {
  2867. return err
  2868. }
  2869. q = getUpdateRulesTimestampQuery()
  2870. _, err = tx.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(time.Now()), action.ID)
  2871. return err
  2872. })
  2873. }
  2874. func sqlCommonDeleteEventAction(action BaseEventAction, dbHandle *sql.DB) error {
  2875. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  2876. defer cancel()
  2877. q := getDeleteEventActionQuery()
  2878. res, err := dbHandle.ExecContext(ctx, q, action.Name)
  2879. if err != nil {
  2880. return err
  2881. }
  2882. return sqlCommonRequireRowAffected(res)
  2883. }
  2884. func sqlCommonGetEventRuleByName(name string, dbHandle sqlQuerier) (EventRule, error) {
  2885. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  2886. defer cancel()
  2887. q := getEventRulesByNameQuery()
  2888. row := dbHandle.QueryRowContext(ctx, q, name)
  2889. rule, err := getEventRuleFromDbRow(row)
  2890. if err != nil {
  2891. return rule, err
  2892. }
  2893. rules, err := getRulesWithActions(ctx, []EventRule{rule}, dbHandle)
  2894. if err != nil {
  2895. return rule, err
  2896. }
  2897. if len(rules) != 1 {
  2898. return rule, fmt.Errorf("unable to associate rule %q with actions", name)
  2899. }
  2900. return rules[0], nil
  2901. }
  2902. func sqlCommonDumpEventRules(dbHandle sqlQuerier) ([]EventRule, error) {
  2903. rules := make([]EventRule, 0, 10)
  2904. ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout)
  2905. defer cancel()
  2906. q := getDumpEventRulesQuery()
  2907. rows, err := dbHandle.QueryContext(ctx, q)
  2908. if err != nil {
  2909. return rules, err
  2910. }
  2911. defer rows.Close()
  2912. for rows.Next() {
  2913. rule, err := getEventRuleFromDbRow(rows)
  2914. if err != nil {
  2915. return rules, err
  2916. }
  2917. rules = append(rules, rule)
  2918. }
  2919. err = rows.Err()
  2920. if err != nil {
  2921. return rules, err
  2922. }
  2923. return getRulesWithActions(ctx, rules, dbHandle)
  2924. }
  2925. func sqlCommonGetRecentlyUpdatedRules(after int64, dbHandle sqlQuerier) ([]EventRule, error) {
  2926. rules := make([]EventRule, 0, 10)
  2927. ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout)
  2928. defer cancel()
  2929. q := getRecentlyUpdatedRulesQuery()
  2930. rows, err := dbHandle.QueryContext(ctx, q, after)
  2931. if err != nil {
  2932. return rules, err
  2933. }
  2934. defer rows.Close()
  2935. for rows.Next() {
  2936. rule, err := getEventRuleFromDbRow(rows)
  2937. if err != nil {
  2938. return rules, err
  2939. }
  2940. rules = append(rules, rule)
  2941. }
  2942. err = rows.Err()
  2943. if err != nil {
  2944. return rules, err
  2945. }
  2946. return getRulesWithActions(ctx, rules, dbHandle)
  2947. }
  2948. func sqlCommonGetEventRules(limit int, offset int, order string, dbHandle sqlQuerier) ([]EventRule, error) {
  2949. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  2950. defer cancel()
  2951. q := getEventRulesQuery(order)
  2952. rules := make([]EventRule, 0, limit)
  2953. rows, err := dbHandle.QueryContext(ctx, q, limit, offset)
  2954. if err != nil {
  2955. return rules, err
  2956. }
  2957. defer rows.Close()
  2958. for rows.Next() {
  2959. rule, err := getEventRuleFromDbRow(rows)
  2960. if err != nil {
  2961. return rules, err
  2962. }
  2963. rules = append(rules, rule)
  2964. }
  2965. err = rows.Err()
  2966. if err != nil {
  2967. return rules, err
  2968. }
  2969. rules, err = getRulesWithActions(ctx, rules, dbHandle)
  2970. if err != nil {
  2971. return rules, err
  2972. }
  2973. for idx := range rules {
  2974. rules[idx].PrepareForRendering()
  2975. }
  2976. return rules, nil
  2977. }
  2978. func sqlCommonAddEventRule(rule *EventRule, dbHandle *sql.DB) error {
  2979. if err := rule.validate(); err != nil {
  2980. return err
  2981. }
  2982. conditions, err := json.Marshal(rule.Conditions)
  2983. if err != nil {
  2984. return err
  2985. }
  2986. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  2987. defer cancel()
  2988. return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error {
  2989. if config.IsShared == 1 {
  2990. _, err := tx.ExecContext(ctx, getRemoveSoftDeletedRuleQuery(), rule.Name)
  2991. if err != nil {
  2992. return err
  2993. }
  2994. }
  2995. q := getAddEventRuleQuery()
  2996. _, err := tx.ExecContext(ctx, q, rule.Name, rule.Description, util.GetTimeAsMsSinceEpoch(time.Now()),
  2997. util.GetTimeAsMsSinceEpoch(time.Now()), rule.Trigger, string(conditions))
  2998. if err != nil {
  2999. return err
  3000. }
  3001. return generateEventRuleActionsMapping(ctx, rule, tx)
  3002. })
  3003. }
  3004. func sqlCommonUpdateEventRule(rule *EventRule, dbHandle *sql.DB) error {
  3005. if err := rule.validate(); err != nil {
  3006. return err
  3007. }
  3008. conditions, err := json.Marshal(rule.Conditions)
  3009. if err != nil {
  3010. return err
  3011. }
  3012. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  3013. defer cancel()
  3014. return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error {
  3015. q := getUpdateEventRuleQuery()
  3016. _, err := tx.ExecContext(ctx, q, rule.Description, util.GetTimeAsMsSinceEpoch(time.Now()),
  3017. rule.Trigger, string(conditions), rule.Name)
  3018. if err != nil {
  3019. return err
  3020. }
  3021. return generateEventRuleActionsMapping(ctx, rule, tx)
  3022. })
  3023. }
  3024. func sqlCommonDeleteEventRule(rule EventRule, softDelete bool, dbHandle *sql.DB) error {
  3025. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  3026. defer cancel()
  3027. return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error {
  3028. if softDelete {
  3029. q := getClearRuleActionMappingQuery()
  3030. _, err := tx.ExecContext(ctx, q, rule.Name)
  3031. if err != nil {
  3032. return err
  3033. }
  3034. }
  3035. q := getDeleteEventRuleQuery(softDelete)
  3036. if softDelete {
  3037. ts := util.GetTimeAsMsSinceEpoch(time.Now())
  3038. res, err := tx.ExecContext(ctx, q, ts, ts, rule.Name)
  3039. if err != nil {
  3040. return err
  3041. }
  3042. return sqlCommonRequireRowAffected(res)
  3043. }
  3044. res, err := tx.ExecContext(ctx, q, rule.Name)
  3045. if err != nil {
  3046. return err
  3047. }
  3048. if err = sqlCommonRequireRowAffected(res); err != nil {
  3049. return err
  3050. }
  3051. return sqlCommonDeleteTask(rule.Name, tx)
  3052. })
  3053. }
  3054. func sqlCommonGetTaskByName(name string, dbHandle sqlQuerier) (Task, error) {
  3055. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  3056. defer cancel()
  3057. task := Task{
  3058. Name: name,
  3059. }
  3060. q := getTaskByNameQuery()
  3061. row := dbHandle.QueryRowContext(ctx, q, name)
  3062. err := row.Scan(&task.UpdateAt, &task.Version)
  3063. if err != nil {
  3064. if errors.Is(err, sql.ErrNoRows) {
  3065. return task, util.NewRecordNotFoundError(err.Error())
  3066. }
  3067. }
  3068. return task, err
  3069. }
  3070. func sqlCommonAddTask(name string, dbHandle *sql.DB) error {
  3071. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  3072. defer cancel()
  3073. q := getAddTaskQuery()
  3074. _, err := dbHandle.ExecContext(ctx, q, name, util.GetTimeAsMsSinceEpoch(time.Now()))
  3075. return err
  3076. }
  3077. func sqlCommonUpdateTask(name string, version int64, dbHandle *sql.DB) error {
  3078. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  3079. defer cancel()
  3080. q := getUpdateTaskQuery()
  3081. res, err := dbHandle.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(time.Now()), name, version)
  3082. if err != nil {
  3083. return err
  3084. }
  3085. return sqlCommonRequireRowAffected(res)
  3086. }
  3087. func sqlCommonUpdateTaskTimestamp(name string, dbHandle *sql.DB) error {
  3088. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  3089. defer cancel()
  3090. q := getUpdateTaskTimestampQuery()
  3091. res, err := dbHandle.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(time.Now()), name)
  3092. if err != nil {
  3093. return err
  3094. }
  3095. return sqlCommonRequireRowAffected(res)
  3096. }
  3097. func sqlCommonDeleteTask(name string, dbHandle sqlQuerier) error {
  3098. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  3099. defer cancel()
  3100. q := getDeleteTaskQuery()
  3101. _, err := dbHandle.ExecContext(ctx, q, name)
  3102. return err
  3103. }
  3104. func sqlCommonAddNode(dbHandle *sql.DB) error {
  3105. if err := currentNode.validate(); err != nil {
  3106. return fmt.Errorf("unable to register cluster node: %w", err)
  3107. }
  3108. data, err := json.Marshal(currentNode.Data)
  3109. if err != nil {
  3110. return err
  3111. }
  3112. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  3113. defer cancel()
  3114. q := getAddNodeQuery()
  3115. _, err = dbHandle.ExecContext(ctx, q, currentNode.Name, string(data), util.GetTimeAsMsSinceEpoch(time.Now()),
  3116. util.GetTimeAsMsSinceEpoch(time.Now()))
  3117. if err != nil {
  3118. return fmt.Errorf("unable to register cluster node: %w", err)
  3119. }
  3120. providerLog(logger.LevelInfo, "registered as cluster node %q, port: %d, proto: %s",
  3121. currentNode.Name, currentNode.Data.Port, currentNode.Data.Proto)
  3122. return nil
  3123. }
  3124. func sqlCommonGetNodeByName(name string, dbHandle *sql.DB) (Node, error) {
  3125. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  3126. defer cancel()
  3127. var data []byte
  3128. var node Node
  3129. q := getNodeByNameQuery()
  3130. row := dbHandle.QueryRowContext(ctx, q, name, util.GetTimeAsMsSinceEpoch(time.Now().Add(activeNodeTimeDiff)))
  3131. err := row.Scan(&node.Name, &data, &node.CreatedAt, &node.UpdatedAt)
  3132. if err != nil {
  3133. if errors.Is(err, sql.ErrNoRows) {
  3134. return node, util.NewRecordNotFoundError(err.Error())
  3135. }
  3136. return node, err
  3137. }
  3138. err = json.Unmarshal(data, &node.Data)
  3139. return node, err
  3140. }
  3141. func sqlCommonGetNodes(dbHandle *sql.DB) ([]Node, error) {
  3142. var nodes []Node
  3143. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  3144. defer cancel()
  3145. q := getNodesQuery()
  3146. rows, err := dbHandle.QueryContext(ctx, q, currentNode.Name,
  3147. util.GetTimeAsMsSinceEpoch(time.Now().Add(activeNodeTimeDiff)))
  3148. if err != nil {
  3149. return nodes, err
  3150. }
  3151. defer rows.Close()
  3152. for rows.Next() {
  3153. var node Node
  3154. var data []byte
  3155. err = rows.Scan(&node.Name, &data, &node.CreatedAt, &node.UpdatedAt)
  3156. if err != nil {
  3157. return nodes, err
  3158. }
  3159. err = json.Unmarshal(data, &node.Data)
  3160. if err != nil {
  3161. return nodes, err
  3162. }
  3163. nodes = append(nodes, node)
  3164. }
  3165. return nodes, rows.Err()
  3166. }
  3167. func sqlCommonUpdateNodeTimestamp(dbHandle *sql.DB) error {
  3168. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  3169. defer cancel()
  3170. q := getUpdateNodeTimestampQuery()
  3171. res, err := dbHandle.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(time.Now()), currentNode.Name)
  3172. if err != nil {
  3173. return err
  3174. }
  3175. return sqlCommonRequireRowAffected(res)
  3176. }
  3177. func sqlCommonCleanupNodes(dbHandle *sql.DB) error {
  3178. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  3179. defer cancel()
  3180. q := getCleanupNodesQuery()
  3181. _, err := dbHandle.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(time.Now().Add(10*activeNodeTimeDiff)))
  3182. return err
  3183. }
  3184. func sqlCommonGetDatabaseVersion(dbHandle sqlQuerier, showInitWarn bool) (schemaVersion, error) {
  3185. var result schemaVersion
  3186. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  3187. defer cancel()
  3188. q := getDatabaseVersionQuery()
  3189. stmt, err := dbHandle.PrepareContext(ctx, q)
  3190. if err != nil {
  3191. providerLog(logger.LevelError, "error preparing database query %#v: %v", q, err)
  3192. if showInitWarn && strings.Contains(err.Error(), sqlTableSchemaVersion) {
  3193. logger.WarnToConsole("database query error, did you forgot to run the \"initprovider\" command?")
  3194. }
  3195. return result, err
  3196. }
  3197. defer stmt.Close()
  3198. row := stmt.QueryRowContext(ctx)
  3199. err = row.Scan(&result.Version)
  3200. return result, err
  3201. }
  3202. func sqlCommonRequireRowAffected(res sql.Result) error {
  3203. // MariaDB/MySQL returns 0 rows affected for updates that don't change anything
  3204. // so we don't check rows affected for updates
  3205. affected, err := res.RowsAffected()
  3206. if err == nil && affected == 0 {
  3207. return util.NewRecordNotFoundError(sql.ErrNoRows.Error())
  3208. }
  3209. return nil
  3210. }
  3211. func sqlCommonUpdateDatabaseVersion(ctx context.Context, dbHandle sqlQuerier, version int) error {
  3212. q := getUpdateDBVersionQuery()
  3213. _, err := dbHandle.ExecContext(ctx, q, version)
  3214. return err
  3215. }
  3216. func sqlCommonExecSQLAndUpdateDBVersion(dbHandle *sql.DB, sqlQueries []string, newVersion int, isUp bool) error {
  3217. if err := sqlAcquireLock(dbHandle); err != nil {
  3218. return err
  3219. }
  3220. defer sqlReleaseLock(dbHandle)
  3221. ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout)
  3222. defer cancel()
  3223. if newVersion > 0 {
  3224. currentVersion, err := sqlCommonGetDatabaseVersion(dbHandle, false)
  3225. if err == nil {
  3226. if (isUp && currentVersion.Version >= newVersion) || (!isUp && currentVersion.Version <= newVersion) {
  3227. providerLog(logger.LevelInfo, "current schema version: %v, requested: %v, did you execute simultaneous migrations?",
  3228. currentVersion.Version, newVersion)
  3229. return nil
  3230. }
  3231. }
  3232. }
  3233. return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error {
  3234. for _, q := range sqlQueries {
  3235. if strings.TrimSpace(q) == "" {
  3236. continue
  3237. }
  3238. _, err := tx.ExecContext(ctx, q)
  3239. if err != nil {
  3240. return err
  3241. }
  3242. }
  3243. if newVersion == 0 {
  3244. return nil
  3245. }
  3246. return sqlCommonUpdateDatabaseVersion(ctx, tx, newVersion)
  3247. })
  3248. }
  3249. func sqlAcquireLock(dbHandle *sql.DB) error {
  3250. ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout)
  3251. defer cancel()
  3252. switch config.Driver {
  3253. case PGSQLDataProviderName:
  3254. _, err := dbHandle.ExecContext(ctx, `SELECT pg_advisory_lock(101,1)`)
  3255. if err != nil {
  3256. return fmt.Errorf("unable to get advisory lock: %w", err)
  3257. }
  3258. providerLog(logger.LevelInfo, "acquired database lock")
  3259. case MySQLDataProviderName:
  3260. var lockResult sql.NullInt64
  3261. err := dbHandle.QueryRowContext(ctx, `SELECT GET_LOCK('sftpgo.migration',30)`).Scan(&lockResult)
  3262. if err != nil {
  3263. return fmt.Errorf("unable to get lock: %w", err)
  3264. }
  3265. if !lockResult.Valid {
  3266. return errors.New("unable to get lock: null value returned")
  3267. }
  3268. if lockResult.Int64 != 1 {
  3269. return fmt.Errorf("unable to get lock, result: %d", lockResult.Int64)
  3270. }
  3271. providerLog(logger.LevelInfo, "acquired database lock")
  3272. }
  3273. return nil
  3274. }
  3275. func sqlReleaseLock(dbHandle *sql.DB) {
  3276. ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
  3277. defer cancel()
  3278. switch config.Driver {
  3279. case PGSQLDataProviderName:
  3280. _, err := dbHandle.ExecContext(ctx, `SELECT pg_advisory_unlock(101,1)`)
  3281. if err != nil {
  3282. providerLog(logger.LevelWarn, "unable to release lock: %v", err)
  3283. } else {
  3284. providerLog(logger.LevelInfo, "released database lock")
  3285. }
  3286. case MySQLDataProviderName:
  3287. _, err := dbHandle.ExecContext(ctx, `SELECT RELEASE_LOCK('sftpgo.migration')`)
  3288. if err != nil {
  3289. providerLog(logger.LevelWarn, "unable to release lock: %v", err)
  3290. } else {
  3291. providerLog(logger.LevelInfo, "released database lock")
  3292. }
  3293. }
  3294. }
  3295. func sqlCommonExecuteTx(ctx context.Context, dbHandle *sql.DB, txFn func(*sql.Tx) error) error {
  3296. if config.Driver == CockroachDataProviderName {
  3297. return crdb.ExecuteTx(ctx, dbHandle, nil, txFn)
  3298. }
  3299. tx, err := dbHandle.BeginTx(ctx, nil)
  3300. if err != nil {
  3301. return err
  3302. }
  3303. err = txFn(tx)
  3304. if err != nil {
  3305. // we don't change the returned error
  3306. tx.Rollback() //nolint:errcheck
  3307. return err
  3308. }
  3309. return tx.Commit()
  3310. }