txn.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521
  1. package memdb
  2. import (
  3. "bytes"
  4. "fmt"
  5. "strings"
  6. "sync/atomic"
  7. "unsafe"
  8. "github.com/hashicorp/go-immutable-radix"
  9. )
  // id is the name of the primary-key index. Insert and Delete read an
// object's primary key through the index registered under this name.
const id = "id"
  // tableIndex names one index of one table. It is a comparable value, so
// it can be used directly as a map key (see Txn.modified).
type tableIndex struct {
	Table, Index string
}
  18. // Txn is a transaction against a MemDB.
  19. // This can be a read or write transaction.
  20. type Txn struct {
  21. db *MemDB
  22. write bool
  23. rootTxn *iradix.Txn
  24. after []func()
  25. modified map[tableIndex]*iradix.Txn
  26. }
  27. // readableIndex returns a transaction usable for reading the given
  28. // index in a table. If a write transaction is in progress, we may need
  29. // to use an existing modified txn.
  30. func (txn *Txn) readableIndex(table, index string) *iradix.Txn {
  31. // Look for existing transaction
  32. if txn.write && txn.modified != nil {
  33. key := tableIndex{table, index}
  34. exist, ok := txn.modified[key]
  35. if ok {
  36. return exist
  37. }
  38. }
  39. // Create a read transaction
  40. path := indexPath(table, index)
  41. raw, _ := txn.rootTxn.Get(path)
  42. indexTxn := raw.(*iradix.Tree).Txn()
  43. return indexTxn
  44. }
  45. // writableIndex returns a transaction usable for modifying the
  46. // given index in a table.
  47. func (txn *Txn) writableIndex(table, index string) *iradix.Txn {
  48. if txn.modified == nil {
  49. txn.modified = make(map[tableIndex]*iradix.Txn)
  50. }
  51. // Look for existing transaction
  52. key := tableIndex{table, index}
  53. exist, ok := txn.modified[key]
  54. if ok {
  55. return exist
  56. }
  57. // Start a new transaction
  58. path := indexPath(table, index)
  59. raw, _ := txn.rootTxn.Get(path)
  60. indexTxn := raw.(*iradix.Tree).Txn()
  61. // Keep this open for the duration of the txn
  62. txn.modified[key] = indexTxn
  63. return indexTxn
  64. }
  65. // Abort is used to cancel this transaction.
  66. // This is a noop for read transactions.
  67. func (txn *Txn) Abort() {
  68. // Noop for a read transaction
  69. if !txn.write {
  70. return
  71. }
  72. // Check if already aborted or committed
  73. if txn.rootTxn == nil {
  74. return
  75. }
  76. // Clear the txn
  77. txn.rootTxn = nil
  78. txn.modified = nil
  79. // Release the writer lock since this is invalid
  80. txn.db.writer.Unlock()
  81. }
  // Commit is used to finalize this transaction. Each sub-transaction opened
// for a (table, index) pair is committed and stored back into the root
// tree, the new root is published to the MemDB with an atomic pointer
// store, the writer lock is released, and then the functions registered
// with Defer run, most recently registered first.
// Commit is a noop for read transactions and for write transactions that
// were already committed or aborted.
func (txn *Txn) Commit() {
	if !txn.write || txn.rootTxn == nil {
		return
	}

	for key, subTxn := range txn.modified {
		txn.rootTxn.Insert(indexPath(key.Table, key.Index), subTxn.Commit())
	}

	root := txn.rootTxn.Commit()
	atomic.StorePointer(&txn.db.root, unsafe.Pointer(root))

	// This txn is now spent; release the writer lock.
	txn.rootTxn, txn.modified = nil, nil
	txn.db.writer.Unlock()

	// Deferred work runs outside the writer lock, in LIFO order.
	for i := len(txn.after) - 1; i >= 0; i-- {
		txn.after[i]()
	}
}
  // Insert is used to add or update an object into the given table.
//
// The object's primary key is read through the "id" index. If an object
// with that key already exists this is an update: the existing object's
// keys are removed from every index and the new object's keys inserted.
// Keys on non-unique indexes have the primary key appended so every entry
// stays distinct.
//
// All index keys are computed and validated before any index is modified,
// so an error returns without leaving the object present in only some of
// the indexes. Errors come from an indexer, or from a missing value on an
// index that does not set AllowMissing.
func (txn *Txn) Insert(table string, obj interface{}) error {
	if !txn.write {
		return fmt.Errorf("cannot insert in read-only transaction")
	}

	// Get the table schema
	tableSchema, ok := txn.db.schema.Tables[table]
	if !ok {
		return fmt.Errorf("invalid table '%s'", table)
	}

	// Get the primary ID of the object
	idSchema := tableSchema.Indexes[id]
	idIndexer := idSchema.Indexer.(SingleIndexer)
	ok, idVal, err := idIndexer.FromObject(obj)
	if err != nil {
		return fmt.Errorf("failed to build primary index: %v", err)
	}
	if !ok {
		return fmt.Errorf("object missing primary index")
	}

	// Lookup the object by ID first, to see if this is an update
	idTxn := txn.writableIndex(table, id)
	existing, update := idTxn.Get(idVal)

	// fromObject returns the keys an object produces for one indexer,
	// presenting single- and multi-value indexers uniformly. An indexer of
	// any other type reports no value.
	fromObject := func(indexer interface{}, o interface{}) (bool, [][]byte, error) {
		switch ix := indexer.(type) {
		case SingleIndexer:
			ok, val, err := ix.FromObject(o)
			return ok, [][]byte{val}, err
		case MultiIndexer:
			return ix.FromObject(o)
		}
		return false, nil, nil
	}

	// withID appends the primary key to a non-unique index value. Capping
	// the capacity with a full slice expression forces append to copy.
	// Appending in place could overwrite a sibling value that shares the
	// indexer's backing array.
	withID := func(val []byte) []byte {
		return append(val[:len(val):len(val)], idVal...)
	}

	// indexChange lists the keys to remove from and add to a single index.
	type indexChange struct {
		name    string
		deletes [][]byte
		inserts [][]byte
	}

	// Phase 1: compute and validate every key without mutating anything.
	changes := make([]indexChange, 0, len(tableSchema.Indexes))
	for name, indexSchema := range tableSchema.Indexes {
		hasVal, vals, err := fromObject(indexSchema.Indexer, obj)
		if err != nil {
			return fmt.Errorf("failed to build index '%s': %v", name, err)
		}

		// Handle non-unique index by computing a unique index.
		// This is done by appending the primary key which must
		// be unique anyways.
		if hasVal && !indexSchema.Unique {
			for i := range vals {
				vals[i] = withID(vals[i])
			}
		}

		change := indexChange{name: name}

		// On an update, the existing object's keys must leave this index.
		if update {
			hasExist, valsExist, err := fromObject(indexSchema.Indexer, existing)
			if err != nil {
				return fmt.Errorf("failed to build index '%s': %v", name, err)
			}
			if hasExist {
				for i, valExist := range valsExist {
					if !indexSchema.Unique {
						valExist = withID(valExist)
					}
					// A key re-inserted at the same position is overwritten by
					// the insert, so its delete can be skipped. That only holds
					// when the new object has a value; otherwise nothing is
					// inserted and the stale key must be removed.
					if !hasVal || i >= len(vals) || !bytes.Equal(valExist, vals[i]) {
						change.deletes = append(change.deletes, valExist)
					}
				}
			}
		}

		// If there is no index value, either this is an error or an expected
		// case where only stale keys (if any) need removing.
		if hasVal {
			change.inserts = vals
		} else if !indexSchema.AllowMissing {
			return fmt.Errorf("missing value for index '%s'", name)
		}

		changes = append(changes, change)
	}

	// Phase 2: apply. Within each index, deletes run before inserts so a key
	// shared by the old and new object (at a different position) survives.
	for _, change := range changes {
		indexTxn := txn.writableIndex(table, change.name)
		for _, key := range change.deletes {
			indexTxn.Delete(key)
		}
		for _, key := range change.inserts {
			indexTxn.Insert(key, obj)
		}
	}
	return nil
}
  // Delete is used to delete a single object from the given table.
// This object must already exist in the table. It is located by its
// primary key (read through the "id" index), and the stored object's keys
// are removed from every index. Keys on non-unique indexes have the primary
// key appended, matching how Insert stores them.
//
// Every key is computed before any index is modified, so an indexer error
// returns without leaving the object removed from only some indexes.
func (txn *Txn) Delete(table string, obj interface{}) error {
	if !txn.write {
		return fmt.Errorf("cannot delete in read-only transaction")
	}

	// Get the table schema
	tableSchema, ok := txn.db.schema.Tables[table]
	if !ok {
		return fmt.Errorf("invalid table '%s'", table)
	}

	// Get the primary ID of the object
	idSchema := tableSchema.Indexes[id]
	idIndexer := idSchema.Indexer.(SingleIndexer)
	ok, idVal, err := idIndexer.FromObject(obj)
	if err != nil {
		return fmt.Errorf("failed to build primary index: %v", err)
	}
	if !ok {
		return fmt.Errorf("object missing primary index")
	}

	// Lookup the object by ID first, to check if we should continue. The
	// stored object (not the argument) drives the deletes below.
	idTxn := txn.writableIndex(table, id)
	existing, found := idTxn.Get(idVal)
	if !found {
		return fmt.Errorf("not found")
	}

	// Phase 1: compute the keys to remove from each index.
	type indexKeys struct {
		name string
		keys [][]byte
	}
	pending := make([]indexKeys, 0, len(tableSchema.Indexes))
	for name, indexSchema := range tableSchema.Indexes {
		var (
			hasVal bool
			vals   [][]byte
			err    error
		)
		switch indexer := indexSchema.Indexer.(type) {
		case SingleIndexer:
			var val []byte
			hasVal, val, err = indexer.FromObject(existing)
			vals = [][]byte{val}
		case MultiIndexer:
			hasVal, vals, err = indexer.FromObject(existing)
		}
		if err != nil {
			return fmt.Errorf("failed to build index '%s': %v", name, err)
		}
		if !hasVal {
			continue
		}

		keys := make([][]byte, 0, len(vals))
		for _, val := range vals {
			// Handle non-unique index by computing a unique index.
			// This is done by appending the primary key which must
			// be unique anyways. The full slice expression forces a copy
			// so append never writes into a backing array the indexer
			// may share between values.
			if !indexSchema.Unique {
				val = append(val[:len(val):len(val)], idVal...)
			}
			keys = append(keys, val)
		}
		pending = append(pending, indexKeys{name: name, keys: keys})
	}

	// Phase 2: remove the object from all the indexes.
	for _, p := range pending {
		indexTxn := txn.writableIndex(table, p.name)
		for _, key := range p.keys {
			indexTxn.Delete(key)
		}
	}
	return nil
}
  // DeleteAll is used to delete all the objects in a given table matching
// the constraints on the index. It returns the number of objects deleted.
// If a delete fails, the count covers only the deletes that succeeded
// before the failure.
func (txn *Txn) DeleteAll(table, index string, args ...interface{}) (int, error) {
	if !txn.write {
		return 0, fmt.Errorf("cannot delete in read-only transaction")
	}

	iter, err := txn.Get(table, index, args...)
	if err != nil {
		return 0, err
	}

	// Drain the iterator into a slice before deleting anything, so the
	// deletes never run while the index is still being iterated.
	var matches []interface{}
	for obj := iter.Next(); obj != nil; obj = iter.Next() {
		matches = append(matches, obj)
	}

	for deleted, obj := range matches {
		if err := txn.Delete(table, obj); err != nil {
			return deleted, err
		}
	}
	return len(matches), nil
}
  309. // First is used to return the first matching object for
  310. // the given constraints on the index
  311. func (txn *Txn) First(table, index string, args ...interface{}) (interface{}, error) {
  312. // Get the index value
  313. indexSchema, val, err := txn.getIndexValue(table, index, args...)
  314. if err != nil {
  315. return nil, err
  316. }
  317. // Get the index itself
  318. indexTxn := txn.readableIndex(table, indexSchema.Name)
  319. // Do an exact lookup
  320. if indexSchema.Unique && val != nil && indexSchema.Name == index {
  321. obj, ok := indexTxn.Get(val)
  322. if !ok {
  323. return nil, nil
  324. }
  325. return obj, nil
  326. }
  327. // Handle non-unique index by using an iterator and getting the first value
  328. iter := indexTxn.Root().Iterator()
  329. iter.SeekPrefix(val)
  330. _, value, _ := iter.Next()
  331. return value, nil
  332. }
  // LongestPrefix is used to fetch the longest prefix match for the given
// constraints on the index. Note that this will not work with the memdb
// StringFieldIndex because it adds null terminators which prevent the
// algorithm from correctly finding a match (it will get to right before the
// null and fail to find a leaf node). This should only be used where the prefix
// given is capable of matching indexed entries directly, which typically only
// applies to a custom indexer. See the unit test for an example.
// It returns nil, with no error, when no indexed key is a prefix of the value.
func (txn *Txn) LongestPrefix(table, index string, args ...interface{}) (interface{}, error) {
	// Only "_prefix" index names are accepted.
	if !strings.HasSuffix(index, "_prefix") {
		return nil, fmt.Errorf("must use '%s_prefix' on index", index)
	}

	indexSchema, val, err := txn.getIndexValue(table, index, args...)
	switch {
	case err != nil:
		return nil, err
	case !indexSchema.Unique:
		// Non-unique keys carry the primary key as a suffix, so a longest
		// prefix match against them is meaningless.
		return nil, fmt.Errorf("index '%s' is not unique", index)
	}

	root := txn.readableIndex(table, indexSchema.Name).Root()
	_, value, found := root.LongestPrefix(val)
	if !found {
		return nil, nil
	}
	return value, nil
}
  362. // getIndexValue is used to get the IndexSchema and the value
  363. // used to scan the index given the parameters. This handles prefix based
  364. // scans when the index has the "_prefix" suffix. The index must support
  365. // prefix iteration.
  366. func (txn *Txn) getIndexValue(table, index string, args ...interface{}) (*IndexSchema, []byte, error) {
  367. // Get the table schema
  368. tableSchema, ok := txn.db.schema.Tables[table]
  369. if !ok {
  370. return nil, nil, fmt.Errorf("invalid table '%s'", table)
  371. }
  372. // Check for a prefix scan
  373. prefixScan := false
  374. if strings.HasSuffix(index, "_prefix") {
  375. index = strings.TrimSuffix(index, "_prefix")
  376. prefixScan = true
  377. }
  378. // Get the index schema
  379. indexSchema, ok := tableSchema.Indexes[index]
  380. if !ok {
  381. return nil, nil, fmt.Errorf("invalid index '%s'", index)
  382. }
  383. // Hot-path for when there are no arguments
  384. if len(args) == 0 {
  385. return indexSchema, nil, nil
  386. }
  387. // Special case the prefix scanning
  388. if prefixScan {
  389. prefixIndexer, ok := indexSchema.Indexer.(PrefixIndexer)
  390. if !ok {
  391. return indexSchema, nil,
  392. fmt.Errorf("index '%s' does not support prefix scanning", index)
  393. }
  394. val, err := prefixIndexer.PrefixFromArgs(args...)
  395. if err != nil {
  396. return indexSchema, nil, fmt.Errorf("index error: %v", err)
  397. }
  398. return indexSchema, val, err
  399. }
  400. // Get the exact match index
  401. val, err := indexSchema.Indexer.FromArgs(args...)
  402. if err != nil {
  403. return indexSchema, nil, fmt.Errorf("index error: %v", err)
  404. }
  405. return indexSchema, val, err
  406. }
  // ResultIterator is used to iterate over a list of results from a Get
// query on a table. Each call to Next returns the next matching object,
// and nil once the results are exhausted.
type ResultIterator interface {
	Next() interface{}
}
  412. // Get is used to construct a ResultIterator over all the
  413. // rows that match the given constraints of an index.
  414. func (txn *Txn) Get(table, index string, args ...interface{}) (ResultIterator, error) {
  415. // Get the index value to scan
  416. indexSchema, val, err := txn.getIndexValue(table, index, args...)
  417. if err != nil {
  418. return nil, err
  419. }
  420. // Get the index itself
  421. indexTxn := txn.readableIndex(table, indexSchema.Name)
  422. indexRoot := indexTxn.Root()
  423. // Get an interator over the index
  424. indexIter := indexRoot.Iterator()
  425. // Seek the iterator to the appropriate sub-set
  426. indexIter.SeekPrefix(val)
  427. // Create an iterator
  428. iter := &radixIterator{
  429. iter: indexIter,
  430. }
  431. return iter, nil
  432. }
  433. // Defer is used to push a new arbitrary function onto a stack which
  434. // gets called when a transaction is committed and finished. Deferred
  435. // functions are called in LIFO order, and only invoked at the end of
  436. // write transactions.
  437. func (txn *Txn) Defer(fn func()) {
  438. txn.after = append(txn.after, fn)
  439. }
  440. // radixIterator is used to wrap an underlying iradix iterator.
  441. // This is much mroe efficient than a sliceIterator as we are not
  442. // materializing the entire view.
  443. type radixIterator struct {
  444. iter *iradix.Iterator
  445. }
  446. func (r *radixIterator) Next() interface{} {
  447. _, value, ok := r.iter.Next()
  448. if !ok {
  449. return nil
  450. }
  451. return value
  452. }