|
| 1 | +package scheduler |
| 2 | + |
| 3 | +import ( |
| 4 | + "context" |
| 5 | + "sync" |
| 6 | + "time" |
| 7 | + |
| 8 | + "github.com/apache/arrow/go/v16/arrow/array" |
| 9 | + "github.com/apache/arrow/go/v16/arrow/memory" |
| 10 | + "github.com/cloudquery/plugin-sdk/v4/message" |
| 11 | + "github.com/cloudquery/plugin-sdk/v4/scalar" |
| 12 | + "github.com/cloudquery/plugin-sdk/v4/schema" |
| 13 | + "github.com/cloudquery/plugin-sdk/v4/writers" |
| 14 | + "github.com/rs/zerolog" |
| 15 | +) |
| 16 | + |
| 17 | +type ( |
| 18 | + BatchSettings struct { |
| 19 | + MaxRows int |
| 20 | + Timeout time.Duration |
| 21 | + } |
| 22 | + |
| 23 | + BatchOption func(settings *BatchSettings) |
| 24 | +) |
| 25 | + |
| 26 | +func WithBatchOptions(options ...BatchOption) Option { |
| 27 | + return func(s *Scheduler) { |
| 28 | + if s.batchSettings == nil { |
| 29 | + s.batchSettings = new(BatchSettings) |
| 30 | + } |
| 31 | + for _, o := range options { |
| 32 | + o(s.batchSettings) |
| 33 | + } |
| 34 | + } |
| 35 | +} |
| 36 | + |
| 37 | +func WithBatchMaxRows(rows int) BatchOption { |
| 38 | + return func(s *BatchSettings) { |
| 39 | + s.MaxRows = rows |
| 40 | + } |
| 41 | +} |
| 42 | + |
| 43 | +func WithBatchTimeout(timeout time.Duration) BatchOption { |
| 44 | + return func(s *BatchSettings) { |
| 45 | + s.Timeout = timeout |
| 46 | + } |
| 47 | +} |
| 48 | + |
| 49 | +func (s *BatchSettings) getBatcher(ctx context.Context, res chan<- message.SyncMessage, logger zerolog.Logger) batcherInterface { |
| 50 | + if s.Timeout > 0 && s.MaxRows > 1 { |
| 51 | + return &batcher{ |
| 52 | + done: ctx.Done(), |
| 53 | + res: res, |
| 54 | + maxRows: s.MaxRows, |
| 55 | + timeout: s.Timeout, |
| 56 | + logger: logger.With().Int("max_rows", s.MaxRows).Dur("timeout", s.Timeout).Logger(), |
| 57 | + } |
| 58 | + } |
| 59 | + |
| 60 | + return &nopBatcher{res: res} |
| 61 | +} |
| 62 | + |
| 63 | +type batcherInterface interface { |
| 64 | + process(res *schema.Resource) |
| 65 | + close() |
| 66 | +} |
| 67 | + |
| 68 | +type nopBatcher struct { |
| 69 | + res chan<- message.SyncMessage |
| 70 | +} |
| 71 | + |
| 72 | +func (n *nopBatcher) process(resource *schema.Resource) { |
| 73 | + n.res <- &message.SyncInsert{Record: resource.GetValues().ToArrowRecord(resource.Table.ToArrowSchema())} |
| 74 | +} |
| 75 | + |
| 76 | +func (*nopBatcher) close() {} |
| 77 | + |
| 78 | +var _ batcherInterface = (*nopBatcher)(nil) |
| 79 | + |
| 80 | +type batcher struct { |
| 81 | + done <-chan struct{} |
| 82 | + |
| 83 | + res chan<- message.SyncMessage |
| 84 | + |
| 85 | + maxRows int |
| 86 | + timeout time.Duration |
| 87 | + |
| 88 | + // using sync primitives by value here implies that batcher is to be used by pointer only |
| 89 | + // workers is a sync.Map rather than a map + mutex pair |
| 90 | + // because worker allocation & lookup falls into one of the sync.Map use-cases, |
| 91 | + // namely, ever-growing cache (write once, read many times). |
| 92 | + workers sync.Map // k = table name, v = *worker |
| 93 | + wg sync.WaitGroup |
| 94 | + |
| 95 | + logger zerolog.Logger |
| 96 | +} |
| 97 | + |
| 98 | +type worker struct { |
| 99 | + ch chan *schema.Resource |
| 100 | + flush chan chan struct{} |
| 101 | + curRows, maxRows int |
| 102 | + builder *array.RecordBuilder // we can reuse that |
| 103 | + res chan<- message.SyncMessage |
| 104 | + logger zerolog.Logger |
| 105 | +} |
| 106 | + |
| 107 | +// send must be called on len(rows) > 0 |
| 108 | +func (w *worker) send() { |
| 109 | + w.logger.Debug().Int("current_rows", w.curRows).Msg("send") |
| 110 | + w.res <- &message.SyncInsert{Record: w.builder.NewRecord()} |
| 111 | + // we need to reserve here as NewRecord (& underlying NewArray calls) reset the memory |
| 112 | + w.builder.Reserve(w.maxRows) |
| 113 | + w.curRows = 0 // reset |
| 114 | +} |
| 115 | + |
| 116 | +func (w *worker) work(done <-chan struct{}, timeout time.Duration) { |
| 117 | + ticker := writers.NewTicker(timeout) |
| 118 | + defer ticker.Stop() |
| 119 | + tickerCh := ticker.Chan() |
| 120 | + |
| 121 | + for { |
| 122 | + select { |
| 123 | + case r, ok := <-w.ch: |
| 124 | + if !ok { |
| 125 | + if w.curRows > 0 { |
| 126 | + w.send() |
| 127 | + } |
| 128 | + return |
| 129 | + } |
| 130 | + |
| 131 | + // append to builder |
| 132 | + scalar.AppendToRecordBuilder(w.builder, r.GetValues()) |
| 133 | + w.curRows++ |
| 134 | + // check if we need to flush |
| 135 | + if w.maxRows > 0 && w.curRows == w.maxRows { |
| 136 | + w.send() |
| 137 | + ticker.Reset(timeout) |
| 138 | + } |
| 139 | + |
| 140 | + case <-tickerCh: |
| 141 | + if w.curRows > 0 { |
| 142 | + w.send() |
| 143 | + } |
| 144 | + |
| 145 | + case ch := <-w.flush: |
| 146 | + if w.curRows > 0 { |
| 147 | + w.send() |
| 148 | + ticker.Reset(timeout) |
| 149 | + } |
| 150 | + close(ch) |
| 151 | + |
| 152 | + case <-done: |
| 153 | + // this means the request was cancelled |
| 154 | + return // after this NO other call will succeed |
| 155 | + } |
| 156 | + } |
| 157 | +} |
| 158 | + |
| 159 | +func (b *batcher) process(res *schema.Resource) { |
| 160 | + table := res.Table |
| 161 | + // already running worker |
| 162 | + v, loaded := b.workers.Load(table.Name) |
| 163 | + if loaded { |
| 164 | + v.(*worker).ch <- res |
| 165 | + return |
| 166 | + } |
| 167 | + |
| 168 | + // we alloc only ch here, as it may be needed right away |
| 169 | + // for instance, if another goroutine will get the value allocated by us |
| 170 | + wr := &worker{ch: make(chan *schema.Resource, 5)} // 5 is quite enough |
| 171 | + v, loaded = b.workers.LoadOrStore(table.Name, wr) |
| 172 | + if loaded { |
| 173 | + // means that the worker was already in tne sync.Map, so we just discard the wr value |
| 174 | + close(wr.ch) // for GC |
| 175 | + v.(*worker).ch <- res // send res to the already allocated worker |
| 176 | + return |
| 177 | + } |
| 178 | + |
| 179 | + // fill in the required data |
| 180 | + // start wr |
| 181 | + b.wg.Add(1) |
| 182 | + go func() { |
| 183 | + defer b.wg.Done() |
| 184 | + |
| 185 | + // fill in the worker fields |
| 186 | + wr.flush = make(chan chan struct{}) |
| 187 | + wr.maxRows = b.maxRows |
| 188 | + wr.builder = array.NewRecordBuilder(memory.DefaultAllocator, table.ToArrowSchema()) |
| 189 | + wr.res = b.res |
| 190 | + wr.builder.Reserve(b.maxRows) |
| 191 | + wr.logger = b.logger.With().Str("table", table.Name).Logger() |
| 192 | + |
| 193 | + // start processing |
| 194 | + wr.work(b.done, b.timeout) |
| 195 | + }() |
| 196 | + |
| 197 | + wr.ch <- res |
| 198 | +} |
| 199 | + |
| 200 | +func (b *batcher) close() { |
| 201 | + b.workers.Range(func(_, v any) bool { |
| 202 | + close(v.(*worker).ch) |
| 203 | + return true |
| 204 | + }) |
| 205 | + b.wg.Wait() |
| 206 | +} |
0 commit comments