Skip to content

Commit 7b96351

Browse files
authored
perf(misconf): retrieve check metadata from annotations once (#8478)
Signed-off-by: nikpivkin <[email protected]>
1 parent 573502e commit 7b96351

File tree

7 files changed

+376
-197
lines changed

7 files changed

+376
-197
lines changed

pkg/iac/rego/embed.go

+7-10
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"fmt"
66
"io/fs"
7+
"maps"
78
"path/filepath"
89
"strings"
910
"sync"
@@ -13,7 +14,6 @@ import (
1314
checks "github.com/aquasecurity/trivy-checks"
1415
"github.com/aquasecurity/trivy/pkg/iac/rules"
1516
"github.com/aquasecurity/trivy/pkg/log"
16-
"github.com/aquasecurity/trivy/pkg/set"
1717
)
1818

1919
var LoadAndRegister = sync.OnceFunc(func() {
@@ -26,9 +26,7 @@ var LoadAndRegister = sync.OnceFunc(func() {
2626
if err != nil {
2727
panic(err)
2828
}
29-
for name, policy := range loadedLibs {
30-
modules[name] = policy
31-
}
29+
maps.Copy(modules, loadedLibs)
3230

3331
RegisterRegoRules(modules)
3432
})
@@ -50,7 +48,6 @@ func RegisterRegoRules(modules map[string]*ast.Module) {
5048
}
5149

5250
retriever := NewMetadataRetriever(compiler)
53-
regoCheckIDs := set.New[string]()
5451

5552
for _, module := range modules {
5653
metadata, err := retriever.RetrieveMetadata(ctx, module)
@@ -66,10 +63,6 @@ func RegisterRegoRules(modules map[string]*ast.Module) {
6663
continue
6764
}
6865

69-
if !metadata.Deprecated {
70-
regoCheckIDs.Append(metadata.AVDID)
71-
}
72-
7366
rules.Register(metadata.ToRule())
7467
}
7568
}
@@ -93,7 +86,7 @@ func LoadPoliciesFromDirs(target fs.FS, paths ...string) (map[string]*ast.Module
9386
return nil
9487
}
9588

96-
if strings.HasSuffix(filepath.Dir(filepath.ToSlash(path)), filepath.Join("advanced", "optional")) {
89+
if isOptionalChecks(path) {
9790
return fs.SkipDir
9891
}
9992

@@ -116,3 +109,7 @@ func LoadPoliciesFromDirs(target fs.FS, paths ...string) (map[string]*ast.Module
116109
}
117110
return modules, nil
118111
}
112+
113+
func isOptionalChecks(path string) bool {
114+
return strings.HasSuffix(filepath.Dir(filepath.ToSlash(path)), filepath.Join("advanced", "optional"))
115+
}

pkg/iac/rego/embed_test.go

+24-8
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,21 @@
1-
package rego
1+
package rego_test
22

33
import (
44
"testing"
5+
"testing/fstest"
56

67
"github.com/open-policy-agent/opa/v1/ast"
78
"github.com/stretchr/testify/assert"
89
"github.com/stretchr/testify/require"
910

1011
checks "github.com/aquasecurity/trivy-checks"
12+
"github.com/aquasecurity/trivy/pkg/iac/rego"
1113
"github.com/aquasecurity/trivy/pkg/iac/rules"
1214
"github.com/aquasecurity/trivy/pkg/iac/scan"
1315
)
1416

1517
func Test_EmbeddedLoading(t *testing.T) {
16-
LoadAndRegister()
18+
rego.LoadAndRegister()
1719

1820
frameworkRules := rules.GetRegistered()
1921
var found bool
@@ -87,19 +89,19 @@ deny[res]{
8789

8890
for _, tc := range testCases {
8991
t.Run(tc.name, func(t *testing.T) {
90-
policies, err := LoadPoliciesFromDirs(checks.EmbeddedLibraryFileSystem, ".")
92+
policies, err := rego.LoadPoliciesFromDirs(checks.EmbeddedLibraryFileSystem, ".")
9193
require.NoError(t, err)
92-
newRule, err := ParseRegoModule("/rules/newrule.rego", tc.inputPolicy)
94+
newRule, err := rego.ParseRegoModule("/rules/newrule.rego", tc.inputPolicy)
9395
require.NoError(t, err)
9496

9597
policies["/rules/newrule.rego"] = newRule
9698
switch {
9799
case tc.expectedError:
98100
assert.Panics(t, func() {
99-
RegisterRegoRules(policies)
101+
rego.RegisterRegoRules(policies)
100102
}, tc.name)
101103
default:
102-
RegisterRegoRules(policies)
104+
rego.RegisterRegoRules(policies)
103105
}
104106
})
105107
}
@@ -185,12 +187,12 @@ deny[res]{
185187
for _, tc := range testCases {
186188
t.Run(tc.name, func(t *testing.T) {
187189
policies := make(map[string]*ast.Module)
188-
newRule, err := ParseRegoModule("/rules/newrule.rego", tc.inputPolicy)
190+
newRule, err := rego.ParseRegoModule("/rules/newrule.rego", tc.inputPolicy)
189191
require.NoError(t, err)
190192

191193
policies["/rules/newrule.rego"] = newRule
192194
assert.NotPanics(t, func() {
193-
RegisterRegoRules(policies)
195+
rego.RegisterRegoRules(policies)
194196
})
195197

196198
for _, rule := range rules.GetRegistered() {
@@ -201,3 +203,17 @@ deny[res]{
201203
})
202204
}
203205
}
206+
207+
func TestLoadPoliciesFromDirs(t *testing.T) {
208+
fsys := fstest.MapFS{
209+
"check.rego": &fstest.MapFile{Data: []byte(`package user.foo`)},
210+
".check.rego": &fstest.MapFile{Data: []byte(`package user.foo`)},
211+
"check_test.rego": &fstest.MapFile{Data: []byte(`package user.foo_test`)},
212+
"test.yaml": &fstest.MapFile{Data: []byte(`foo: bar`)},
213+
"checks/test.rego": &fstest.MapFile{Data: []byte(`package user.checks.foo`)},
214+
}
215+
216+
modules, err := rego.LoadPoliciesFromDirs(fsys, ".")
217+
require.NoError(t, err)
218+
assert.Len(t, modules, 2)
219+
}

pkg/iac/rego/load.go

+94-36
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ import (
55
"fmt"
66
"io"
77
"io/fs"
8+
"maps"
9+
"slices"
810
"strings"
911

1012
"github.com/open-policy-agent/opa/v1/ast"
@@ -13,6 +15,7 @@ import (
1315

1416
"github.com/aquasecurity/trivy/pkg/log"
1517
"github.com/aquasecurity/trivy/pkg/set"
18+
"github.com/aquasecurity/trivy/pkg/version/doc"
1619
)
1720

1821
var builtinNamespaces = set.New("builtin", "defsec", "appshield")
@@ -27,6 +30,10 @@ func IsBuiltinNamespace(namespace string) bool {
2730
})
2831
}
2932

33+
func getModuleNamespace(module *ast.Module) string {
34+
return strings.TrimPrefix(module.Package.Path.String(), "data.")
35+
}
36+
3037
func IsRegoFile(name string) bool {
3138
return strings.HasSuffix(name, bundle.RegoExt) && !strings.HasSuffix(name, "_test"+bundle.RegoExt)
3239
}
@@ -99,9 +106,7 @@ func (s *Scanner) LoadPolicies(srcFS fs.FS) error {
99106
if err != nil {
100107
return fmt.Errorf("failed to load rego checks from %s: %w", s.policyDirs, err)
101108
}
102-
for name, policy := range loaded {
103-
s.policies[name] = policy
104-
}
109+
maps.Copy(s.policies, loaded)
105110
s.logger.Debug("Checks from disk are loaded", log.Int("count", len(loaded)))
106111
}
107112

@@ -110,9 +115,7 @@ func (s *Scanner) LoadPolicies(srcFS fs.FS) error {
110115
if err != nil {
111116
return fmt.Errorf("failed to load rego checks from reader(s): %w", err)
112117
}
113-
for name, policy := range loaded {
114-
s.policies[name] = policy
115-
}
118+
maps.Copy(s.policies, loaded)
116119
s.logger.Debug("Checks from readers are loaded", log.Int("count", len(loaded)))
117120
}
118121

@@ -193,13 +196,13 @@ func (s *Scanner) findMatchedEmbeddedCheck(badPolicy *ast.Module) *ast.Module {
193196
}
194197
}
195198

196-
badPolicyMeta, err := metadataFromRegoModule(badPolicy)
199+
badPolicyMeta, err := MetadataFromAnnotations(badPolicy)
197200
if err != nil {
198201
return nil
199202
}
200203

201204
for _, embeddedCheck := range s.embeddedChecks {
202-
meta, err := metadataFromRegoModule(embeddedCheck)
205+
meta, err := MetadataFromAnnotations(embeddedCheck)
203206
if err != nil {
204207
continue
205208
}
@@ -230,6 +233,9 @@ func (s *Scanner) prunePoliciesWithError(compiler *ast.Compiler) error {
230233
}
231234

232235
func (s *Scanner) compilePolicies(srcFS fs.FS, paths []string) error {
236+
for path, module := range s.policies {
237+
s.handleModulesMetadata(path, module)
238+
}
233239

234240
schemaSet, err := BuildSchemaSetFromPolicies(s.policies, paths, srcFS, s.customSchemas)
235241
if err != nil {
@@ -249,54 +255,106 @@ func (s *Scanner) compilePolicies(srcFS fs.FS, paths []string) error {
249255
}
250256
return s.compilePolicies(srcFS, paths)
251257
}
252-
retriever := NewMetadataRetriever(compiler)
253258

254-
if err := s.filterModules(retriever); err != nil {
255-
return err
259+
s.retriever = NewMetadataRetriever(compiler)
260+
261+
if err := s.filterModules(); err != nil {
262+
return fmt.Errorf("filter modules: %w", err)
256263
}
257264
s.compiler = compiler
258-
s.retriever = retriever
259265
return nil
260266
}
261267

262-
func (s *Scanner) filterModules(retriever *MetadataRetriever) error {
268+
func (s *Scanner) handleModulesMetadata(path string, module *ast.Module) {
269+
if moduleHasLegacyInputFormat(module) {
270+
s.logger.Warn(
271+
"Module has legacy input format - please update to use annotations",
272+
log.FilePath(module.Package.Location.File),
273+
log.String("details", doc.URL("/docs/scanner/misconfiguration/custom", "input")),
274+
)
275+
}
276+
277+
if moduleHasLegacyMetadataFormat(module) {
278+
s.logger.Warn(
279+
"Module has legacy metadata format - please update to use annotations",
280+
log.FilePath(module.Package.Location.File),
281+
log.String("details", doc.URL("/docs/scanner/misconfiguration/custom", "metadata")),
282+
)
283+
return
284+
}
285+
286+
metadata, err := MetadataFromAnnotations(module)
287+
if err != nil {
288+
s.logger.Error(
289+
"Failed to retrieve metadata from annotations",
290+
log.FilePath(module.Package.Location.File),
291+
log.Err(err),
292+
)
293+
return
294+
}
263295

296+
if metadata != nil {
297+
s.moduleMetadata[path] = metadata
298+
}
299+
}
300+
301+
// moduleHasLegacyMetadataFormat checks if the module has a legacy metadata format.
302+
// Returns true if the metadata is represented as a “__rego_metadata__” rule,
303+
// which was used before annotations were introduced.
304+
func moduleHasLegacyMetadataFormat(module *ast.Module) bool {
305+
return slices.ContainsFunc(module.Rules, func(rule *ast.Rule) bool {
306+
return rule.Head.Name.Equal(ast.Var("__rego_metadata__"))
307+
})
308+
}
309+
310+
// moduleHasLegacyInputFormat checks if the module has a legacy input format.
311+
// Returns true if the input is represented as a “__rego_input__” rule,
312+
// which was used before annotations were introduced.
313+
func moduleHasLegacyInputFormat(module *ast.Module) bool {
314+
return slices.ContainsFunc(module.Rules, func(rule *ast.Rule) bool {
315+
return rule.Head.Name.Equal(ast.Var("__rego_input__"))
316+
})
317+
}
318+
319+
// filterModules filters the Rego modules based on metadata.
320+
func (s *Scanner) filterModules() error {
264321
filtered := make(map[string]*ast.Module)
265322
for name, module := range s.policies {
266-
meta, err := retriever.RetrieveMetadata(context.TODO(), module)
323+
metadata, err := s.metadataForModule(context.Background(), name, module, nil)
267324
if err != nil {
268-
return err
325+
return fmt.Errorf("retrieve metadata for module %s: %w", name, err)
269326
}
270327

271-
if !meta.hasAnyFramework(s.frameworks) {
272-
continue
273-
}
274-
275-
if IsBuiltinNamespace(getModuleNamespace(module)) {
276-
if s.disabledCheckIDs.Contains(meta.ID) { // ignore builtin disabled checks
277-
continue
278-
}
279-
}
280-
281-
if len(meta.InputOptions.Selectors) == 0 {
282-
if !meta.Library {
283-
s.logger.Warn(
284-
"Module has no input selectors - it will be loaded for all inputs!",
285-
log.FilePath(module.Package.Location.File),
286-
log.String("module", name),
287-
)
288-
}
328+
if s.isModuleApplicable(module, metadata, name) {
289329
filtered[name] = module
290-
continue
291330
}
292-
293-
filtered[name] = module
294331
}
295332

296333
s.policies = filtered
297334
return nil
298335
}
299336

337+
func (s *Scanner) isModuleApplicable(module *ast.Module, metadata *StaticMetadata, name string) bool {
338+
if !metadata.hasAnyFramework(s.frameworks) {
339+
return false
340+
}
341+
342+
// ignore disabled built-in checks
343+
if IsBuiltinNamespace(getModuleNamespace(module)) && s.disabledCheckIDs.Contains(metadata.ID) {
344+
return false
345+
}
346+
347+
if len(metadata.InputOptions.Selectors) == 0 && !metadata.Library {
348+
s.logger.Warn(
349+
"Module has no input selectors - it will be loaded for all inputs",
350+
log.FilePath(module.Package.Location.File),
351+
log.String("module", name),
352+
)
353+
}
354+
355+
return true
356+
}
357+
300358
func ParseRegoModule(name, input string) (*ast.Module, error) {
301359
return ast.ParseModuleWithOpts(name, input, ast.ParserOptions{
302360
ProcessAnnotation: true,

0 commit comments

Comments
 (0)