Browse Source

Detect unnecessary imports instead of hardcoding

lu4p 4 months ago
parent
commit
237e0b7b7c
2 changed files with 126 additions and 162 deletions
  1. 47 42
      main.go
  2. 79 120
      runtime_strip.go

+ 47 - 42
main.go

@@ -748,6 +748,7 @@ func transformCompile(args []string) ([]string, error) {
 		if curPkg.ImportPath == "runtime" && flagTiny {
 			// strip unneeded runtime code
 			stripRuntime(filename, file)
+			tf.removeUnnecessaryImports(file)
 		}
 		tf.handleDirectives(file.Comments)
 		file = tf.transformGo(file)
@@ -1441,6 +1442,51 @@ func recordedAsNotObfuscated(obj types.Object) bool {
 	return ok
 }
 
+func (tf *transformer) removeUnnecessaryImports(file *ast.File) {
+	usedImports := make(map[string]bool)
+	ast.Inspect(file, func(n ast.Node) bool {
+		node, ok := n.(*ast.Ident)
+		if !ok {
+			return true
+		}
+
+		uses, ok := tf.info.Uses[node].(*types.PkgName)
+		if !ok {
+			return true
+		}
+
+		usedImports[uses.Imported().Path()] = true
+
+		return true
+	})
+
+	for _, imp := range file.Imports {
+		if imp.Name != nil && (imp.Name.Name == "_" || imp.Name.Name == ".") {
+			continue
+		}
+
+		path, err := strconv.Unquote(imp.Path.Value)
+		if err != nil {
+			panic(err)
+		}
+
+		// The import path can't be used directly here, because the actual
+		// path resolved via go/types might be different from the naive path.
+		lpkg, err := listPackage(path)
+		if err != nil {
+			panic(err)
+		}
+
+		if usedImports[lpkg.ImportPath] {
+			continue
+		}
+
+		if !astutil.DeleteImport(fset, file, path) {
+			panic(fmt.Sprintf("cannot delete unused import: %q", path))
+		}
+	}
+}
+
 // transformGo obfuscates the provided Go syntax file.
 func (tf *transformer) transformGo(file *ast.File) *ast.File {
 	// Only obfuscate the literals here if the flag is on
@@ -1453,48 +1499,7 @@ func (tf *transformer) transformGo(file *ast.File) *ast.File {
 		file = literals.Obfuscate(file, tf.info, fset, tf.linkerVariableStrings)
 
 		// some imported constants might not be needed anymore, remove unnecessary imports
-		usedImports := make(map[string]bool)
-		ast.Inspect(file, func(n ast.Node) bool {
-			node, ok := n.(*ast.Ident)
-			if !ok {
-				return true
-			}
-
-			uses, ok := tf.info.Uses[node].(*types.PkgName)
-			if !ok {
-				return true
-			}
-
-			usedImports[uses.Imported().Path()] = true
-
-			return true
-		})
-
-		for _, imp := range file.Imports {
-			if imp.Name != nil && (imp.Name.Name == "_" || imp.Name.Name == ".") {
-				continue
-			}
-
-			path, err := strconv.Unquote(imp.Path.Value)
-			if err != nil {
-				panic(err)
-			}
-
-			// The import path can't be used directly here, because the actual
-			// path resolved via go/types might be different from the naive path.
-			lpkg, err := listPackage(path)
-			if err != nil {
-				panic(err)
-			}
-
-			if usedImports[lpkg.ImportPath] {
-				continue
-			}
-
-			if !astutil.DeleteImport(fset, file, path) {
-				panic(fmt.Sprintf("cannot delete unused import: %q", path))
-			}
-		}
+		tf.removeUnnecessaryImports(file)
 	}
 
 	pre := func(cursor *astutil.Cursor) bool {

+ 79 - 120
runtime_strip.go

@@ -5,7 +5,6 @@ package main
 
 import (
 	"go/ast"
-	"go/token"
 	"strings"
 
 	ah "mvdan.cc/garble/internal/asthelper"
@@ -35,121 +34,93 @@ func stripRuntime(filename string, file *ast.File) {
 	}
 
 	for _, decl := range file.Decls {
-		switch x := decl.(type) {
-		case *ast.FuncDecl:
-			switch filename {
-			case "error.go":
-				// only used in panics
-				switch x.Name.Name {
-				case "printany", "printanycustomtype":
-					x.Body.List = nil
-				}
-			case "mgcscavenge.go":
-				// used in tracing the scavenger
-				if x.Name.Name == "printScavTrace" {
-					x.Body.List = nil
-					break
-				}
-			case "mprof.go":
-				// remove all functions that print debug/tracing info
-				// of the runtime
-				if strings.HasPrefix(x.Name.Name, "trace") {
-					x.Body.List = nil
-				}
-			case "panic.go":
-				// used for printing panics
-				switch x.Name.Name {
-				case "preprintpanics", "printpanics":
-					x.Body.List = nil
-				}
-			case "print.go":
-				// only used in tracebacks
-				if x.Name.Name == "hexdumpWords" {
-					x.Body.List = nil
-					break
-				}
-			case "proc.go":
-				// used in tracing the scheduler
-				if x.Name.Name == "schedtrace" {
-					x.Body.List = nil
-					break
-				}
-			case "runtime1.go":
-				usesEnv := func(node ast.Node) bool {
-					seen := false
-					ast.Inspect(node, func(node ast.Node) bool {
-						ident, ok := node.(*ast.Ident)
-						if ok && ident.Name == "gogetenv" {
-							seen = true
-							return false
-						}
-						return true
-					})
-					return seen
-				}
-			filenames:
-				switch x.Name.Name {
-				case "parsedebugvars":
-					// keep defaults for GODEBUG cgocheck and invalidptr,
-					// remove code that reads GODEBUG via gogetenv
-					for i, stmt := range x.Body.List {
-						if usesEnv(stmt) {
-							x.Body.List = x.Body.List[:i]
-							break filenames
-						}
+		funcDecl, ok := decl.(*ast.FuncDecl)
+		if !ok {
+			continue
+		}
+
+		switch filename {
+		case "error.go":
+			// only used in panics
+			switch funcDecl.Name.Name {
+			case "printany", "printanycustomtype":
+				funcDecl.Body.List = nil
+			}
+		case "mgcscavenge.go":
+			// used in tracing the scavenger
+			if funcDecl.Name.Name == "printScavTrace" {
+				funcDecl.Body.List = nil
+			}
+		case "mprof.go":
+			// remove all functions that print debug/tracing info
+			// of the runtime
+			if strings.HasPrefix(funcDecl.Name.Name, "trace") {
+				funcDecl.Body.List = nil
+			}
+		case "panic.go":
+			// used for printing panics
+			switch funcDecl.Name.Name {
+			case "preprintpanics", "printpanics":
+				funcDecl.Body.List = nil
+			}
+		case "print.go":
+			// only used in tracebacks
+			if funcDecl.Name.Name == "hexdumpWords" {
+				funcDecl.Body.List = nil
+			}
+		case "proc.go":
+			// used in tracing the scheduler
+			if funcDecl.Name.Name == "schedtrace" {
+				funcDecl.Body.List = nil
+			}
+		case "runtime1.go":
+			usesEnv := func(node ast.Node) bool {
+				seen := false
+				ast.Inspect(node, func(node ast.Node) bool {
+					ident, ok := node.(*ast.Ident)
+					if ok && ident.Name == "gogetenv" {
+						seen = true
+						return false
 					}
-					panic("did not see any gogetenv call in parsedebugvars")
-				case "setTraceback":
-					// tracebacks are completely hidden, no
-					// sense keeping this function
-					x.Body.List = nil
-				}
-			case "traceback.go":
-				// only used for printing tracebacks
-				switch x.Name.Name {
-				case "tracebackdefers", "printcreatedby", "printcreatedby1", "traceback", "tracebacktrap", "traceback1", "printAncestorTraceback",
-					"printAncestorTracebackFuncInfo", "goroutineheader", "tracebackothers", "tracebackHexdump", "printCgoTraceback":
-					x.Body.List = nil
-				case "printOneCgoTraceback":
-					x.Body = ah.BlockStmt(ah.ReturnStmt(ah.IntLit(0)))
-				default:
-					if strings.HasPrefix(x.Name.Name, "print") {
-						x.Body.List = nil
+					return true
+				})
+				return seen
+			}
+		filenames:
+			switch funcDecl.Name.Name {
+			case "parsedebugvars":
+				// keep defaults for GODEBUG cgocheck and invalidptr,
+				// remove code that reads GODEBUG via gogetenv
+				for i, stmt := range funcDecl.Body.List {
+					if usesEnv(stmt) {
+						funcDecl.Body.List = funcDecl.Body.List[:i]
+						break filenames
 					}
 				}
-			default:
-				break
-			}
-		case *ast.GenDecl:
-			if x.Tok != token.IMPORT {
-				continue
+				panic("did not see any gogetenv call in parsedebugvars")
+			case "setTraceback":
+				// tracebacks are completely hidden, no
+				// sense keeping this function
+				funcDecl.Body.List = nil
 			}
-
-			switch filename {
-			case "print.go":
-				// was used in hexdumpWords
-				x.Specs = removeImport(`"internal/goarch"`, x.Specs)
-			case "traceback.go":
-				// was used in traceback1
-				x.Specs = removeImport(`"runtime/internal/atomic"`, x.Specs)
+		case "traceback.go":
+			// only used for printing tracebacks
+			switch funcDecl.Name.Name {
+			case "tracebackdefers", "printcreatedby", "printcreatedby1", "traceback", "tracebacktrap", "traceback1", "printAncestorTraceback",
+				"printAncestorTracebackFuncInfo", "goroutineheader", "tracebackothers", "tracebackHexdump", "printCgoTraceback":
+				funcDecl.Body.List = nil
+			case "printOneCgoTraceback":
+				funcDecl.Body = ah.BlockStmt(ah.ReturnStmt(ah.IntLit(0)))
+			default:
+				if strings.HasPrefix(funcDecl.Name.Name, "print") {
+					funcDecl.Body.List = nil
+				}
 			}
-
 		}
+
 	}
 
-	switch filename {
-	case "runtime1.go":
-		// On Go 1.17.x, the code above results in runtime1.go having an
-		// unused import. Make it an underscore import.
-		// If this is a recurring problem, we could go for a more
-		// generic solution like x/tools/imports.
-		for _, imp := range file.Imports {
-			if imp.Path.Value == `"internal/bytealg"` {
-				imp.Name = &ast.Ident{Name: "_"}
-				break
-			}
-		}
-	case "print.go":
+	if filename == "print.go" {
 		file.Decls = append(file.Decls, hidePrintDecl)
 		return
 	}
@@ -160,18 +131,6 @@ func stripRuntime(filename string, file *ast.File) {
 	ast.Inspect(file, stripPrints)
 }
 
-func removeImport(importPath string, specs []ast.Spec) []ast.Spec {
-	for i, spec := range specs {
-		imp := spec.(*ast.ImportSpec)
-		if imp.Path.Value == importPath {
-			specs = append(specs[:i], specs[i+1:]...)
-			break
-		}
-	}
-
-	return specs
-}
-
 var hidePrintDecl = &ast.FuncDecl{
 	Name: ast.NewIdent("hidePrint"),
 	Type: &ast.FuncType{Params: &ast.FieldList{