Skip to content

Commit a06aacd

Browse files
committed
revert writers to their original form
1 parent be9c869 commit a06aacd

File tree

3 files changed

+91
-97
lines changed

3 files changed

+91
-97
lines changed

Diff for: writers/batchwriter/batchwriter.go

+24 -27
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,7 @@ import (
66
"sync"
77
"time"
88

9-
"github.com/cloudquery/plugin-sdk/v4/internal/batch"
9+
"github.com/apache/arrow/go/v16/arrow/util"
1010
"github.com/cloudquery/plugin-sdk/v4/message"
1111
"github.com/cloudquery/plugin-sdk/v4/schema"
1212
"github.com/cloudquery/plugin-sdk/v4/writers"
@@ -122,7 +122,7 @@ func (w *BatchWriter) Close(context.Context) error {
122122
}
123123

124124
func (w *BatchWriter) worker(ctx context.Context, tableName string, ch <-chan *message.WriteInsert, flush <-chan chan bool) {
125-
limit := batch.CappedAt(w.batchSizeBytes, w.batchSize)
125+
var bytes, rows int64
126126
resources := make([]*message.WriteInsert, 0, w.batchSize) // at least we have 1 row per record
127127

128128
ticker := writers.NewTicker(w.batchTimeout)
@@ -134,50 +134,47 @@ func (w *BatchWriter) worker(ctx context.Context, tableName string, ch <-chan *m
134134
w.flushTable(ctx, tableName, resources)
135135
clear(resources)
136136
resources = resources[:0]
137-
limit.Reset()
137+
bytes, rows = 0, 0
138138
}
139-
140139
for {
141140
select {
142141
case r, ok := <-ch:
143142
if !ok {
144-
if limit.Rows() > 0 {
143+
if rows > 0 {
145144
w.flushTable(ctx, tableName, resources)
146145
}
147146
return
148147
}
149148

150-
if r.Record.NumRows() == 0 {
151-
// skip empty ones
152-
continue
153-
}
154-
155-
add, toFlush, rest := batch.SliceRecord(r.Record, limit)
156-
if add != nil {
157-
resources = append(resources, &message.WriteInsert{Record: add.Record})
158-
}
159-
if len(toFlush) > 0 || rest != nil || limit.ReachedLimit() {
160-
// flush current batch
161-
send()
162-
ticker.Reset(w.batchTimeout)
163-
}
164-
for _, sliceToFlush := range toFlush {
165-
resources = append(resources, &message.WriteInsert{Record: sliceToFlush})
149+
recordRows, recordBytes := r.Record.NumRows(), util.TotalRecordSize(r.Record)
150+
if (w.batchSize > 0 && rows+recordRows > w.batchSize) ||
151+
(w.batchSizeBytes > 0 && bytes+recordBytes > w.batchSizeBytes) {
152+
if rows == 0 {
153+
// New record overflows batch by itself.
154+
// Flush right away.
155+
// TODO: slice
156+
resources = append(resources, r)
157+
send()
158+
ticker.Reset(w.batchTimeout)
159+
continue
160+
}
161+
// rows > 0
166162
send()
167163
ticker.Reset(w.batchTimeout)
168164
}
169-
170-
// set the remainder
171-
if rest != nil {
172-
resources = append(resources, &message.WriteInsert{Record: rest.Record})
165+
if recordRows > 0 {
166+
// only save records with rows
167+
resources = append(resources, r)
168+
rows += recordRows
169+
bytes += recordBytes
173170
}
174171

175172
case <-tickerCh:
176-
if limit.Rows() > 0 {
173+
if rows > 0 {
177174
send()
178175
}
179176
case done := <-flush:
180-
if limit.Rows() > 0 {
177+
if rows > 0 {
181178
send()
182179
ticker.Reset(w.batchTimeout)
183180
}

Diff for: writers/mixedbatchwriter/mixedbatchwriter.go

+20 -25
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@ import (
44
"context"
55
"time"
66

7-
"github.com/cloudquery/plugin-sdk/v4/internal/batch"
7+
"github.com/apache/arrow/go/v16/arrow/util"
88
"github.com/cloudquery/plugin-sdk/v4/message"
99
"github.com/cloudquery/plugin-sdk/v4/writers"
1010
"github.com/rs/zerolog"
@@ -92,7 +92,8 @@ func (w *MixedBatchWriter) Write(ctx context.Context, msgChan <-chan message.Wri
9292
insert := &insertBatchManager{
9393
batch: make([]*message.WriteInsert, 0, w.batchSize),
9494
writeFunc: w.client.InsertBatch,
95-
limit: batch.CappedAt(w.batchSizeBytes, w.batchSize),
95+
maxRows: w.batchSize,
96+
maxBytes: w.batchSizeBytes,
9697
logger: w.logger,
9798
}
9899
deleteStale := &batchManager[message.WriteDeleteStales, *message.WriteDeleteStale]{
@@ -198,53 +199,47 @@ func (m *batchManager[A, T]) flush(ctx context.Context) error {
198199

199200
// special batch manager for insert messages that also keeps track of the total size of the batch
200201
type insertBatchManager struct {
201-
batch []*message.WriteInsert
202-
writeFunc func(ctx context.Context, messages message.WriteInserts) error
203-
limit *batch.Cap
204-
logger zerolog.Logger
202+
batch []*message.WriteInsert
203+
writeFunc func(ctx context.Context, messages message.WriteInserts) error
204+
curRows, maxRows int64
205+
curBytes, maxBytes int64
206+
logger zerolog.Logger
205207
}
206208

207209
func (m *insertBatchManager) append(ctx context.Context, msg *message.WriteInsert) error {
208-
add, toFlush, rest := batch.SliceRecord(msg.Record, m.limit)
209-
if add != nil {
210-
m.batch = append(m.batch, &message.WriteInsert{Record: add.Record})
211-
}
212-
if len(toFlush) > 0 || rest != nil || m.limit.ReachedLimit() {
213-
// flush current batch
214-
if err := m.flush(ctx); err != nil {
215-
return err
216-
}
217-
}
218-
for _, sliceToFlush := range toFlush {
219-
m.batch = append(m.batch, &message.WriteInsert{Record: sliceToFlush})
210+
recordRows, recordBytes := msg.Record.NumRows(), util.TotalRecordSize(msg.Record)
211+
if (m.maxRows > 0 && m.curRows+recordRows > m.maxRows) ||
212+
(m.maxBytes > 0 && m.curBytes+recordBytes > m.maxBytes) {
220213
if err := m.flush(ctx); err != nil {
221214
return err
222215
}
223216
}
224217

225-
// set the remainder
226-
if rest != nil {
227-
m.batch = append(m.batch, &message.WriteInsert{Record: rest.Record})
218+
if recordRows > 0 {
219+
// only save records with rows
220+
m.batch = append(m.batch, msg)
221+
m.curRows += recordRows
222+
m.curBytes += recordBytes
228223
}
229224

230225
return nil
231226
}
232227

233228
func (m *insertBatchManager) flush(ctx context.Context) error {
234-
if m.limit.Rows() == 0 {
229+
if m.curRows == 0 {
235230
// no rows to insert
236231
return nil
237232
}
238233
start := time.Now()
239234
err := m.writeFunc(ctx, m.batch)
240235
if err != nil {
241-
m.logger.Err(err).Int64("len", m.limit.Rows()).Dur("duration", time.Since(start)).Msg("failed to write batch")
236+
m.logger.Err(err).Int64("len", m.curRows).Dur("duration", time.Since(start)).Msg("failed to write batch")
242237
return err
243238
}
244-
m.logger.Debug().Int64("len", m.limit.Rows()).Dur("duration", time.Since(start)).Msg("batch written successfully")
239+
m.logger.Debug().Int64("len", m.curRows).Dur("duration", time.Since(start)).Msg("batch written successfully")
245240

246241
clear(m.batch) // GC can work
247242
m.batch = m.batch[:0]
248-
m.limit.Reset()
243+
m.curRows, m.curBytes = 0, 0
249244
return nil
250245
}

Diff for: writers/streamingbatchwriter/streamingbatchwriter.go

+47 -45
Original file line number | Diff line number | Diff line change
@@ -25,7 +25,7 @@ import (
2525
"sync"
2626
"time"
2727

28-
"github.com/cloudquery/plugin-sdk/v4/internal/batch"
28+
"github.com/apache/arrow/go/v16/arrow/util"
2929
"github.com/cloudquery/plugin-sdk/v4/message"
3030
"github.com/cloudquery/plugin-sdk/v4/schema"
3131
"github.com/cloudquery/plugin-sdk/v4/writers"
@@ -233,9 +233,9 @@ func (w *StreamingBatchWriter) startWorker(ctx context.Context, errCh chan<- err
233233
flush: make(chan chan bool),
234234
errCh: errCh,
235235

236-
limit: batch.CappedAt(0, w.batchSizeRows),
237-
batchTimeout: w.batchTimeout,
238-
tickerFn: w.tickerFn,
236+
batchSizeRows: w.batchSizeRows,
237+
batchTimeout: w.batchTimeout,
238+
tickerFn: w.tickerFn,
239239
}
240240

241241
w.workersWaitGroup.Add(1)
@@ -257,9 +257,9 @@ func (w *StreamingBatchWriter) startWorker(ctx context.Context, errCh chan<- err
257257
flush: make(chan chan bool),
258258
errCh: errCh,
259259

260-
limit: batch.CappedAt(0, w.batchSizeRows),
261-
batchTimeout: w.batchTimeout,
262-
tickerFn: w.tickerFn,
260+
batchSizeRows: w.batchSizeRows,
261+
batchTimeout: w.batchTimeout,
262+
tickerFn: w.tickerFn,
263263
}
264264

265265
w.workersWaitGroup.Add(1)
@@ -283,9 +283,10 @@ func (w *StreamingBatchWriter) startWorker(ctx context.Context, errCh chan<- err
283283
flush: make(chan chan bool),
284284
errCh: errCh,
285285

286-
limit: batch.CappedAt(w.batchSizeBytes, w.batchSizeRows),
287-
batchTimeout: w.batchTimeout,
288-
tickerFn: w.tickerFn,
286+
batchSizeRows: w.batchSizeRows,
287+
batchSizeBytes: w.batchSizeBytes,
288+
batchTimeout: w.batchTimeout,
289+
tickerFn: w.tickerFn,
289290
}
290291
w.workersLock.Lock()
291292
wrOld, ok := w.insertWorkers[tableName]
@@ -319,9 +320,9 @@ func (w *StreamingBatchWriter) startWorker(ctx context.Context, errCh chan<- err
319320
flush: make(chan chan bool),
320321
errCh: errCh,
321322

322-
limit: batch.CappedAt(w.batchSizeBytes, w.batchSizeRows),
323-
batchTimeout: w.batchTimeout,
324-
tickerFn: w.tickerFn,
323+
batchSizeRows: w.batchSizeRows,
324+
batchTimeout: w.batchTimeout,
325+
tickerFn: w.tickerFn,
325326
}
326327

327328
w.workersWaitGroup.Add(1)
@@ -340,17 +341,19 @@ type streamingWorkerManager[T message.WriteMessage] struct {
340341
flush chan chan bool
341342
errCh chan<- error
342343

343-
limit *batch.Cap
344-
batchTimeout time.Duration
345-
tickerFn writers.TickerFunc
344+
batchSizeRows int64
345+
batchSizeBytes int64
346+
batchTimeout time.Duration
347+
tickerFn writers.TickerFunc
346348
}
347349

348350
func (s *streamingWorkerManager[T]) run(ctx context.Context, wg *sync.WaitGroup, tableName string) {
349351
defer wg.Done()
350352
var (
351-
clientCh chan T
352-
clientErrCh chan error
353-
open bool
353+
clientCh chan T
354+
clientErrCh chan error
355+
open bool
356+
sizeBytes, sizeRows int64
354357
)
355358

356359
ensureOpened := func() {
@@ -379,7 +382,7 @@ func (s *streamingWorkerManager[T]) run(ctx context.Context, wg *sync.WaitGroup,
379382
}
380383
}
381384
open = false
382-
s.limit.Reset()
385+
sizeBytes, sizeRows = 0, 0
383386
}
384387
defer closeFlush()
385388

@@ -395,45 +398,44 @@ func (s *streamingWorkerManager[T]) run(ctx context.Context, wg *sync.WaitGroup,
395398
return
396399
}
397400

401+
recordRows := int64(1) // at least 1 row for messages without records
402+
var recordBytes int64
398403
if ins, ok := any(r).(*message.WriteInsert); ok {
399-
add, toFlush, rest := batch.SliceRecord(ins.Record, s.limit)
400-
if add != nil {
401-
ensureOpened()
402-
clientCh <- any(&message.WriteInsert{Record: add.Record}).(T)
403-
}
404-
if len(toFlush) > 0 || rest != nil || s.limit.ReachedLimit() {
405-
// flush current batch
406-
closeFlush()
407-
ticker.Reset(s.batchTimeout)
408-
}
409-
for _, sliceToFlush := range toFlush {
404+
recordBytes = util.TotalRecordSize(ins.Record)
405+
recordRows = ins.Record.NumRows()
406+
}
407+
408+
if (s.batchSizeRows > 0 && sizeRows+recordRows > s.batchSizeRows) ||
409+
(s.batchSizeBytes > 0 && sizeBytes+recordBytes > s.batchSizeBytes) {
410+
if sizeRows == 0 {
411+
// New record overflows batch by itself.
412+
// Flush right away.
413+
// TODO: slice
410414
ensureOpened()
411-
clientCh <- any(&message.WriteInsert{Record: sliceToFlush}).(T)
415+
clientCh <- r
412416
closeFlush()
413417
ticker.Reset(s.batchTimeout)
418+
continue
414419
}
420+
// sizeRows > 0
421+
closeFlush()
422+
ticker.Reset(s.batchTimeout)
423+
}
415424

416-
// set the remainder
417-
if rest != nil {
418-
ensureOpened()
419-
clientCh <- any(&message.WriteInsert{Record: rest.Record}).(T)
420-
}
421-
} else {
425+
if recordRows > 0 {
426+
// only save records with rows
422427
ensureOpened()
423428
clientCh <- r
424-
s.limit.AddRows(1)
425-
if s.limit.ReachedLimit() {
426-
closeFlush()
427-
ticker.Reset(s.batchTimeout)
428-
}
429+
sizeRows += recordRows
430+
sizeBytes += recordBytes
429431
}
430432

431433
case <-tickerCh:
432-
if s.limit.Rows() > 0 {
434+
if sizeRows > 0 {
433435
closeFlush()
434436
}
435437
case done := <-s.flush:
436-
if s.limit.Rows() > 0 {
438+
if sizeRows > 0 {
437439
closeFlush()
438440
ticker.Reset(s.batchTimeout)
439441
}

0 commit comments

Comments
 (0)