Browse Source

support code taking the address of a []byte literal (#530)

shellhazard 3 months ago
parent
commit
22e3d30216
3 changed files with 132 additions and 39 deletions
  1. 1 0
      AUTHORS
  2. 102 39
      internal/literals/literals.go
  3. 29 0
      testdata/scripts/literals.txt

+ 1 - 0
AUTHORS

@@ -12,3 +12,4 @@ Nicholas Jones <me@nicholasjon.es>
 Zachary Wasserman <zachwass2000@gmail.com>
 lu4p <lu4p@pm.me>
 pagran <pagran@protonmail.com>
+shellhazard <shellhazard@tutanota.com>

+ 102 - 39
internal/literals/literals.go

@@ -77,59 +77,102 @@ func Obfuscate(file *ast.File, info *types.Info, fset *token.FileSet, linkString
 			return true
 		}
 
-		if node, ok := node.(*ast.CompositeLit); ok {
-			if len(node.Elts) == 0 || len(node.Elts) > maxSizeBytes {
+		switch node := node.(type) {
+		case *ast.UnaryExpr:
+			// Account for the possibility of address operators like
+			// &[]byte used inline with function arguments.
+			//
+			// See issue #520.
+
+			if node.Op != token.AND {
 				return true
 			}
 
-			byteType := types.Universe.Lookup("byte").Type()
-
-			var arrayLen int64
-			switch y := info.TypeOf(node.Type).(type) {
-			case *types.Array:
-				if y.Elem() != byteType {
-					return true
-				}
-
-				arrayLen = y.Len()
-
-			case *types.Slice:
-				if y.Elem() != byteType {
-					return true
+			if child, ok := node.X.(*ast.CompositeLit); ok {
+				newnode := handleCompositeLiteral(true, child, info)
+				if newnode != nil {
+					cursor.Replace(newnode)
 				}
+			}
 
-			default:
+		case *ast.CompositeLit:
+			// We replaced the &[]byte{...} case above. Here we account for the
+			// standard []byte{...} or [4]byte{...} value form.
+			//
+			// We need two separate calls to cursor.Replace, as it only supports
+			// replacing the node we're currently visiting, and the pointer variant
+			// requires us to move the ampersand operator.
+
+			parent, ok := cursor.Parent().(*ast.UnaryExpr)
+			if ok && parent.Op == token.AND {
 				return true
 			}
 
-			data := make([]byte, 0, len(node.Elts))
+			newnode := handleCompositeLiteral(false, node, info)
+			if newnode != nil {
+				cursor.Replace(newnode)
+			}
+		}
 
-			for _, el := range node.Elts {
-				elType := info.Types[el]
+		return true
+	}
 
-				if elType.Value == nil || elType.Value.Kind() != constant.Int {
-					return true
-				}
+	return astutil.Apply(file, pre, post).(*ast.File)
+}
 
-				value, ok := constant.Uint64Val(elType.Value)
-				if !ok {
-					panic(fmt.Sprintf("cannot parse byte value: %v", elType.Value))
-				}
+// handleCompositeLiteral checks if the input node is []byte or [...]byte and
+// calls the appropriate obfuscation method, returning a new node that should
+// be used to replace it.
+//
+// If the input is not a byte slice or array, the node is returned as-is and
+// the second return value will be false.
+func handleCompositeLiteral(isPointer bool, node *ast.CompositeLit, info *types.Info) ast.Node {
+	if len(node.Elts) == 0 || len(node.Elts) > maxSizeBytes {
+		return nil
+	}
 
-				data = append(data, byte(value))
-			}
+	byteType := types.Universe.Lookup("byte").Type()
 
-			if arrayLen > 0 {
-				cursor.Replace(withPos(obfuscateByteArray(data, arrayLen), node.Pos()))
-			} else {
-				cursor.Replace(withPos(obfuscateByteSlice(data), node.Pos()))
-			}
+	var arrayLen int64
+	switch y := info.TypeOf(node.Type).(type) {
+	case *types.Array:
+		if y.Elem() != byteType {
+			return nil
 		}
 
-		return true
+		arrayLen = y.Len()
+
+	case *types.Slice:
+		if y.Elem() != byteType {
+			return nil
+		}
+
+	default:
+		return nil
 	}
 
-	return astutil.Apply(file, pre, post).(*ast.File)
+	data := make([]byte, 0, len(node.Elts))
+
+	for _, el := range node.Elts {
+		elType := info.Types[el]
+
+		if elType.Value == nil || elType.Value.Kind() != constant.Int {
+			return nil
+		}
+
+		value, ok := constant.Uint64Val(elType.Value)
+		if !ok {
+			panic(fmt.Sprintf("cannot parse byte value: %v", elType.Value))
+		}
+
+		data = append(data, byte(value))
+	}
+
+	if arrayLen > 0 {
+		return withPos(obfuscateByteArray(isPointer, data, arrayLen), node.Pos())
+	}
+
+	return withPos(obfuscateByteSlice(isPointer, data), node.Pos())
 }
 
 // withPos sets any token.Pos fields under node which affect printing to pos.
@@ -186,14 +229,25 @@ func obfuscateString(data string) *ast.CallExpr {
 	return ah.LambdaCall(ast.NewIdent("string"), block)
 }
 
-func obfuscateByteSlice(data []byte) *ast.CallExpr {
+func obfuscateByteSlice(isPointer bool, data []byte) *ast.CallExpr {
 	obfuscator := randObfuscator()
 	block := obfuscator.obfuscate(data)
+
+	if isPointer {
+		block.List = append(block.List, ah.ReturnStmt(&ast.UnaryExpr{
+			Op: token.AND,
+			X:  ast.NewIdent("data"),
+		}))
+		return ah.LambdaCall(&ast.StarExpr{
+			X: &ast.ArrayType{Elt: ast.NewIdent("byte")},
+		}, block)
+	}
+
 	block.List = append(block.List, ah.ReturnStmt(ast.NewIdent("data")))
 	return ah.LambdaCall(&ast.ArrayType{Elt: ast.NewIdent("byte")}, block)
 }
 
-func obfuscateByteArray(data []byte, length int64) *ast.CallExpr {
+func obfuscateByteArray(isPointer bool, data []byte, length int64) *ast.CallExpr {
 	obfuscator := randObfuscator()
 	block := obfuscator.obfuscate(data)
 
@@ -224,10 +278,19 @@ func obfuscateByteArray(data []byte, length int64) *ast.CallExpr {
 				},
 			}},
 		},
-		ah.ReturnStmt(ast.NewIdent("newdata")),
 	}
 
+	var retexpr ast.Expr = ast.NewIdent("newdata")
+	if isPointer {
+		retexpr = &ast.UnaryExpr{X: retexpr, Op: token.AND}
+	}
+
+	sliceToArray = append(sliceToArray, ah.ReturnStmt(retexpr))
 	block.List = append(block.List, sliceToArray...)
 
+	if isPointer {
+		return ah.LambdaCall(&ast.StarExpr{X: arrayType}, block)
+	}
+
 	return ah.LambdaCall(arrayType, block)
 }

+ 29 - 0
testdata/scripts/literals.txt

@@ -264,6 +264,31 @@ func byteTest() {
 
 	e := []byte{0x43, 11_1, 0b01101101, 'p', 'l', 'e', 'x'}
 	println(string(e))
+
+	// Testing for issue #520.
+	func(s []byte) {
+		print(string(s))
+	}([]byte("chungus"))
+	println()
+
+	func(s *[]byte) {
+		print(string(*s))
+	}(&[]byte{99, 104, 117, 110, 103, 117, 115})
+	println()
+
+	func(s [7]byte) {
+		for _, elm := range s {
+			print(elm, ",")
+		}
+	}([7]byte{99, 104, 117, 110, 103, 117, 115})
+	println()
+
+	func(s *[7]byte) {
+		for _, elm := range s {
+			print(elm, ",")
+		}
+	}(&[7]byte{99, 104, 117, 110, 103, 117, 115})
+	println()
 }
 
 func stringTypeFunc(s stringType) stringType {
@@ -376,6 +401,10 @@ foo
 12,13,
 12,13,0,0,
 Complex
+chungus
+chungus
+99,104,117,110,103,117,115,
+99,104,117,110,103,117,115,
 obfuscated with shadowed builtins (vars)
 obfuscated with shadowed builtins (types)
 1: literal in an array