Skip to content

Commit 233d988

Browse files
authored
Merge pull request #20 from mrf345/testing
Replace `go-observable` with `StatusObservable`, and optimize
2 parents 8ac1145 + 988db4d commit 233d988

File tree

9 files changed

+120
-73
lines changed

9 files changed

+120
-73
lines changed

go.mod

-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ module github.com/mrf345/safelock-cli
33
go 1.22
44

55
require (
6-
github.com/GianlucaGuarini/go-observable v0.0.0-20171228155646-e39e699e0a00
76
github.com/inancgumus/screen v0.0.0-20190314163918-06e984b86ed3
87
github.com/mholt/archiver/v4 v4.0.0-alpha.8
98
github.com/spf13/cobra v1.8.1

go.sum

-2
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@ cloud.google.com/go/storage v1.5.0/go.mod h1:tpKbwo567HUNpVclU5sGELwQWBDZ8gh0Zeo
1717
dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU=
1818
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
1919
github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo=
20-
github.com/GianlucaGuarini/go-observable v0.0.0-20171228155646-e39e699e0a00 h1:4wp5bMTx9eV6in+ZKiUsyeOqYdp9ooqpw1YWXjwVHJo=
21-
github.com/GianlucaGuarini/go-observable v0.0.0-20171228155646-e39e699e0a00/go.mod h1:2pqNiwoZ8Fj1HBGWyPTXW/iPD332sJzTp3Iy0dIcFMc=
2220
github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY=
2321
github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig=
2422
github.com/bodgit/plumbing v1.2.0 h1:gg4haxoKphLjml+tgnecR4yLBV5zo4HAZGCtAh3xCzM=

safelock/core.go

+11-32
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@ package safelock
22

33
import (
44
"crypto/cipher"
5-
"crypto/rand"
6-
"errors"
75
"fmt"
86
"io"
97

@@ -30,8 +28,7 @@ func newAeadWriter(pwd string, w io.Writer, config EncryptionConfig, errs chan e
3028
errs: errs,
3129
aeadDone: make(chan bool, 2),
3230
}
33-
aw.writeSalt(w)
34-
go aw.loadAead()
31+
go aw.writeSaltAndLoad(w)
3532
return aw
3633
}
3734

@@ -55,20 +52,15 @@ func (aw *aeadWrapper) getAead() cipher.AEAD {
5552
return aw.aead
5653
}
5754

58-
func (aw *aeadWrapper) writeSalt(w io.Writer) {
59-
var err error
60-
61-
aw.salt = make([]byte, aw.config.SaltLength)
62-
63-
if _, err = io.ReadFull(rand.Reader, aw.salt); err != nil {
64-
aw.errs <- fmt.Errorf("failed to create random salt > %w", err)
65-
return
66-
}
55+
func (aw *aeadWrapper) writeSaltAndLoad(w io.Writer) {
56+
aw.salt = (<-aw.config.random)[:aw.config.SaltLength]
6757

68-
if _, err = w.Write(aw.salt); err != nil {
58+
if _, err := w.Write(aw.salt); err != nil {
6959
aw.errs <- fmt.Errorf("failed to write salt > %w", err)
7060
return
7161
}
62+
63+
aw.loadAead()
7264
}
7365

7466
func (aw *aeadWrapper) readSalt(r InputReader) {
@@ -94,11 +86,6 @@ func (aw *aeadWrapper) readSalt(r InputReader) {
9486
func (aw *aeadWrapper) loadAead() {
9587
var err error
9688

97-
if aw.config.SaltLength > len(aw.salt) {
98-
aw.errs <- errors.New("missing salt, most probably race condition")
99-
return
100-
}
101-
10289
key := argon2.IDKey(
10390
aw.pwd,
10491
aw.salt,
@@ -116,27 +103,19 @@ func (aw *aeadWrapper) loadAead() {
116103
aw.aeadDone <- true
117104
}
118105

119-
func (aw *aeadWrapper) encrypt(chunk []byte) (encrypted []byte, err error) {
120-
aead := aw.getAead()
106+
func (aw *aeadWrapper) encrypt(chunk []byte) []byte {
121107
idx := []byte(fmt.Sprintf("%d", aw.counter))
122-
nonce := make([]byte, aead.NonceSize())
123-
124-
if _, err = rand.Read(nonce); err != nil {
125-
aw.errs <- fmt.Errorf("failed to generate nonce > %w", err)
126-
return
127-
}
128-
129-
encrypted = append(nonce, aead.Seal(nil, nonce, chunk, idx)...)
108+
aead := aw.getAead()
109+
nonce := (<-aw.config.random)[:aead.NonceSize()]
130110
aw.counter += 1
131-
132-
return
111+
return append(nonce, aead.Seal(nil, nonce, chunk, idx)...)
133112
}
134113

135114
func (aw *aeadWrapper) decrypt(chunk []byte) (output []byte, err error) {
136115
aead := aw.getAead()
137116

138117
if aead.NonceSize() > len(chunk) {
139-
err = &slErrs.ErrFailedToAuthenticate{Msg: "chunk size size"}
118+
err = &slErrs.ErrFailedToAuthenticate{Msg: "invalid chunk size"}
140119
aw.errs <- err
141120
return
142121
}

safelock/decrypt.go

+6-9
Original file line numberDiff line numberDiff line change
@@ -16,21 +16,18 @@ import (
1616
// and then outputs the content into `outputPath` which must be a valid path to an existing directory
1717
//
1818
// NOTE: `ctx` context is optional you can pass `nil` and the method will handle it
19-
func (sl Safelock) Decrypt(ctx context.Context, input InputReader, outputPath, password string) (err error) {
19+
func (sl *Safelock) Decrypt(ctx context.Context, input InputReader, outputPath, password string) (err error) {
2020
errs := make(chan error)
2121
signals, closeSignals := utils.GetExitSignals()
22+
unSubStatus := sl.StatusObs.Subscribe(sl.logStatus)
2223

2324
if ctx == nil {
2425
ctx = context.Background()
2526
}
2627

27-
sl.StatusObs.
28-
On(StatusUpdate.Str(), sl.logStatus).
29-
Trigger(StatusStart.Str())
30-
31-
defer sl.StatusObs.
32-
Off(StatusUpdate.Str(), sl.logStatus).
33-
Trigger(StatusEnd.Str())
28+
sl.StatusObs.next(StatusItem{Event: StatusStart})
29+
defer sl.StatusObs.next(StatusItem{Event: StatusEnd})
30+
defer unSubStatus()
3431

3532
go func() {
3633
if err = sl.validateDecryptionPaths(outputPath); err != nil {
@@ -68,7 +65,7 @@ func (sl Safelock) Decrypt(ctx context.Context, input InputReader, outputPath, p
6865
err = context.DeadlineExceeded
6966
return
7067
case err = <-errs:
71-
sl.StatusObs.Trigger(StatusError.Str(), err)
68+
sl.StatusObs.next(StatusItem{Event: StatusError, Err: err})
7269
return
7370
case <-signals:
7471
return

safelock/encrypt.go

+8-10
Original file line numberDiff line numberDiff line change
@@ -16,21 +16,20 @@ import (
1616
// outputs into an object `output` that implements [io.Writer] such as [io.File]
1717
//
1818
// NOTE: `ctx` context is optional you can pass `nil` and the method will handle it
19-
func (sl Safelock) Encrypt(ctx context.Context, inputPaths []string, output io.Writer, password string) (err error) {
19+
func (sl *Safelock) Encrypt(ctx context.Context, inputPaths []string, output io.Writer, password string) (err error) {
2020
errs := make(chan error)
21+
go sl.loadRandom(errs)
22+
aead := newAeadWriter(password, output, sl.EncryptionConfig, errs)
2123
signals, closeSignals := utils.GetExitSignals()
24+
unSubStatus := sl.StatusObs.Subscribe(sl.logStatus)
2225

2326
if ctx == nil {
2427
ctx = context.Background()
2528
}
2629

27-
sl.StatusObs.
28-
On(StatusUpdate.Str(), sl.logStatus).
29-
Trigger(StatusStart.Str())
30-
31-
defer sl.StatusObs.
32-
Off(StatusUpdate.Str(), sl.logStatus).
33-
Trigger(StatusEnd.Str())
30+
sl.StatusObs.next(StatusItem{Event: StatusStart})
31+
defer sl.StatusObs.next(StatusItem{Event: StatusEnd})
32+
defer unSubStatus()
3433

3534
go func() {
3635
if err = sl.validateEncryptionInputs(inputPaths, password); err != nil {
@@ -39,7 +38,6 @@ func (sl Safelock) Encrypt(ctx context.Context, inputPaths []string, output io.W
3938
}
4039

4140
ctx, cancel := context.WithCancel(ctx)
42-
aead := newAeadWriter(password, output, sl.EncryptionConfig, errs)
4341
writer := newWriter(password, output, 20.0, cancel, aead)
4442

4543
if err = sl.encryptFiles(ctx, inputPaths, writer); err != nil {
@@ -63,7 +61,7 @@ func (sl Safelock) Encrypt(ctx context.Context, inputPaths []string, output io.W
6361
err = context.DeadlineExceeded
6462
return
6563
case err = <-errs:
66-
sl.StatusObs.Trigger(StatusError.Str(), err)
64+
sl.StatusObs.next(StatusItem{Event: StatusError, Err: err})
6765
return
6866
case <-signals:
6967
return

safelock/events.go

+68-5
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,84 @@
11
package safelock
22

3+
import "sync"
4+
35
// [safelock.Safelock.StatusObs] streaming event keys type
46
type StatusEvent string
57

68
// [safelock.Safelock.StatusObs] streaming event keys
79
const (
8-
StatusStart StatusEvent = "start_status" // encryption/decryption has started (no args)
9-
StatusEnd StatusEvent = "end_status" // encryption/decryption has ended (no args)
10-
StatusUpdate StatusEvent = "update_status" // new status update (status string, percent float64)
11-
StatusError StatusEvent = "error_status" // encryption/decryption failed (error)
10+
StatusStart StatusEvent = "start_status" // encryption/decryption has started
11+
StatusEnd StatusEvent = "end_status" // encryption/decryption has ended
12+
StatusUpdate StatusEvent = "update_status" // new status update
13+
StatusError StatusEvent = "error_status" // encryption/decryption failed
1214
)
1315

1416
// return event key value as string
1517
func (se StatusEvent) Str() string {
1618
return string(se)
1719
}
1820

21+
// item used to communicate status changes
22+
type StatusItem struct {
23+
// status change event key
24+
Event StatusEvent
25+
// completion percent
26+
Percent float64
27+
// optional status change text
28+
Msg string
29+
// optional status change error
30+
Err error
31+
}
32+
33+
// observable like data structure used to stream status changes
34+
type StatusObservable struct {
35+
mu sync.RWMutex
36+
subs map[int]func(StatusItem)
37+
counter int
38+
}
39+
40+
// creates a new [safelock.StatusObservable] instance
41+
func NewStatusObs() *StatusObservable {
42+
return &StatusObservable{
43+
subs: make(map[int]func(StatusItem)),
44+
}
45+
}
46+
47+
// adds a new status change subscriber, and returns the unsubscribe function
48+
func (obs *StatusObservable) Subscribe(callback func(StatusItem)) func() {
49+
obs.mu.Lock()
50+
id := obs.counter
51+
obs.subs[id] = callback
52+
obs.counter += 1
53+
obs.mu.Unlock()
54+
55+
// returns unsubscribe function
56+
return func() {
57+
obs.mu.Lock()
58+
delete(obs.subs, id)
59+
obs.mu.Unlock()
60+
}
61+
}
62+
63+
// clears all subscriptions
64+
func (obs *StatusObservable) Unsubscribe() {
65+
obs.mu.Lock()
66+
clear(obs.subs)
67+
obs.mu.Unlock()
68+
}
69+
70+
func (obs *StatusObservable) next(value StatusItem) {
71+
obs.mu.RLock()
72+
for _, callback := range obs.subs {
73+
go callback(value)
74+
}
75+
obs.mu.RUnlock()
76+
}
77+
1978
func (sl *Safelock) updateStatus(status string, percent float64) {
20-
sl.StatusObs.Trigger(StatusUpdate.Str(), status, percent)
79+
sl.StatusObs.next(StatusItem{
80+
Event: StatusUpdate,
81+
Msg: status,
82+
Percent: percent,
83+
})
2184
}

safelock/logger.go

+4-2
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ func (sl *Safelock) log(msg string, params ...any) {
1717
}
1818
}
1919

20-
func (sl *Safelock) logStatus(status string, percent float64) {
21-
sl.log("%s (%.2f%%)\n", status, percent)
20+
func (sl *Safelock) logStatus(status StatusItem) {
21+
if status.Event == StatusUpdate {
22+
sl.log("%s (%.2f%%)\n", status.Msg, status.Percent)
23+
}
2224
}

safelock/safelock.go

+22-5
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@ package safelock
22

33
import (
44
"context"
5+
"crypto/rand"
6+
"fmt"
57
"io"
68
"runtime"
79

8-
"github.com/GianlucaGuarini/go-observable"
910
"github.com/klauspost/compress/zstd"
1011
"github.com/mholt/archiver/v4"
1112
)
@@ -14,7 +15,7 @@ import (
1415
type EncryptionConfig struct {
1516
// encryption key length (default: 32)
1617
KeyLength uint32
17-
// encryption salt length (default: 12)
18+
// encryption salt length (default: 16)
1819
SaltLength int
1920
// number of argon2 hashing iterations (default: 3)
2021
IterationCount uint32
@@ -26,6 +27,21 @@ type EncryptionConfig struct {
2627
MinPasswordLength int
2728
// ratio to create file header size based on (default: 1024 * 4)
2829
HeaderRatio int
30+
31+
random chan []byte
32+
}
33+
34+
func (ec *EncryptionConfig) loadRandom(errs chan error) {
35+
for {
36+
nonce := make([]byte, 50)
37+
38+
if _, err := rand.Read(nonce); err != nil {
39+
errs <- fmt.Errorf("failed to generate random bytes > %w", err)
40+
return
41+
}
42+
43+
ec.random <- nonce
44+
}
2945
}
3046

3147
// archiving and compression configuration settings
@@ -49,7 +65,7 @@ type Safelock struct {
4965
// disable all output and logs (default: false)
5066
Quiet bool
5167
// observable instance that allows us to stream the status to multiple listeners
52-
StatusObs *observable.Observable
68+
StatusObs *StatusObservable
5369
}
5470

5571
// creates a new [safelock.Safelock] instance with the default recommended options
@@ -66,12 +82,13 @@ func New() *Safelock {
6682
EncryptionConfig: EncryptionConfig{
6783
IterationCount: 3,
6884
KeyLength: 32,
69-
SaltLength: 12,
85+
SaltLength: 16,
7086
MinPasswordLength: 8,
7187
HeaderRatio: 1024 * 4,
7288
MemSize: 64 * 1024,
7389
Threads: uint8(runtime.NumCPU()),
90+
random: make(chan []byte, 500),
7491
},
75-
StatusObs: observable.New(),
92+
StatusObs: NewStatusObs(),
7693
}
7794
}

safelock/writer.go

+1-7
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,7 @@ func newWriter(
3333
}
3434

3535
func (sw *safelockWriter) Write(chunk []byte) (written int, err error) {
36-
var encrypted []byte
37-
38-
if encrypted, err = sw.aead.encrypt(chunk); err != nil {
39-
return 0, sw.handleErr(err)
40-
}
41-
42-
if written, err = sw.writer.Write(encrypted); err != nil {
36+
if written, err = sw.writer.Write(sw.aead.encrypt(chunk)); err != nil {
4337
err = fmt.Errorf("can't write encrypted chunk > %w", err)
4438
return written, sw.handleErr(err)
4539
}

0 commit comments

Comments
 (0)