Browse Source

fix the float comparison by using Abs(a,b) < 1e-6 approach (IEEE 754). Move the initializiation of expr helpers (#2492)

Thibault "bui" Koechlin 1 year ago
parent
commit
8f6659a2ec
3 changed files with 39 additions and 15 deletions
  1. 7 0
      pkg/exprhelpers/expr_lib.go
  2. 21 9
      pkg/exprhelpers/helpers.go
  3. 11 6
      pkg/hubtest/parser_assert.go

+ 7 - 0
pkg/exprhelpers/expr_lib.go

@@ -419,6 +419,13 @@ var exprFuncs = []exprCustomFunc{
 			new(func() (string, error)),
 		},
 	},
+	{
+		name:     "FloatApproxEqual",
+		function: FloatApproxEqual,
+		signature: []interface{}{
+			new(func(float64, float64) bool),
+		},
+	},
 }
 
 //go 1.20 "CutPrefix":              strings.CutPrefix,

+ 21 - 9
pkg/exprhelpers/helpers.go

@@ -4,6 +4,7 @@ import (
 	"bufio"
 	"encoding/base64"
 	"fmt"
+	"math"
 	"net"
 	"net/url"
 	"os"
@@ -54,6 +55,16 @@ var exprFunctionOptions []expr.Option
 var keyValuePattern = regexp.MustCompile(`(?P<key>[^=\s]+)=(?:"(?P<quoted_value>[^"\\]*(?:\\.[^"\\]*)*)"|(?P<value>[^=\s]+)|\s*)`)
 
 func GetExprOptions(ctx map[string]interface{}) []expr.Option {
+	if len(exprFunctionOptions) == 0 {
+		exprFunctionOptions = []expr.Option{}
+		for _, function := range exprFuncs {
+			exprFunctionOptions = append(exprFunctionOptions,
+				expr.Function(function.name,
+					function.function,
+					function.signature...,
+				))
+		}
+	}
 	ret := []expr.Option{}
 	ret = append(ret, exprFunctionOptions...)
 	ret = append(ret, expr.Env(ctx))
@@ -66,15 +77,6 @@ func Init(databaseClient *database.Client) error {
 	dataFileRe2 = make(map[string][]*re2.Regexp)
 	dbClient = databaseClient
 
-	exprFunctionOptions = []expr.Option{}
-	for _, function := range exprFuncs {
-		exprFunctionOptions = append(exprFunctionOptions,
-			expr.Function(function.name,
-				function.function,
-				function.signature...,
-			))
-	}
-
 	return nil
 }
 
@@ -589,6 +591,16 @@ func Match(params ...any) (any, error) {
 	return matched, nil
 }
 
+func FloatApproxEqual(params ...any) (any, error) {
+	float1 := params[0].(float64)
+	float2 := params[1].(float64)
+
+	if math.Abs(float1-float2) < 1e-6 {
+		return true, nil
+	}
+	return false, nil
+}
+
 func B64Decode(params ...any) (any, error) {
 	encoded := params[0].(string)
 	decoded, err := base64.StdEncoding.DecodeString(encoded)

+ 11 - 6
pkg/hubtest/parser_assert.go

@@ -103,7 +103,6 @@ func (p *ParserAssert) AssertFile(testFile string) error {
 		p.NbAssert += 1
 		if !ok {
 			log.Debugf("%s is FALSE", scanner.Text())
-			//fmt.SPrintf(" %s '%s'\n", emoji.RedSquare, scanner.Text())
 			failedAssert := &AssertFail{
 				File:       p.File,
 				Line:       nbLine,
@@ -112,10 +111,13 @@ func (p *ParserAssert) AssertFile(testFile string) error {
 			}
 			variableRE := regexp.MustCompile(`(?P<variable>[^  =]+) == .*`)
 			match := variableRE.FindStringSubmatch(scanner.Text())
+			variable := ""
 			if len(match) == 0 {
 				log.Infof("Couldn't get variable of line '%s'", scanner.Text())
+				variable = scanner.Text()
+			} else {
+				variable = match[1]
 			}
-			variable := match[1]
 			result, err := p.EvalExpression(variable)
 			if err != nil {
 				log.Errorf("unable to evaluate variable '%s': %s", variable, err)
@@ -123,6 +125,7 @@ func (p *ParserAssert) AssertFile(testFile string) error {
 			}
 			failedAssert.Debug[variable] = result
 			p.Fails = append(p.Fails, *failedAssert)
+
 			continue
 		}
 		//fmt.Printf(" %s '%s'\n", emoji.GreenSquare, scanner.Text())
@@ -154,13 +157,14 @@ func (p *ParserAssert) RunExpression(expression string) (interface{}, error) {
 	env := map[string]interface{}{"results": *p.TestData}
 
 	if runtimeFilter, err = expr.Compile(expression, exprhelpers.GetExprOptions(env)...); err != nil {
+		log.Errorf("failed to compile '%s' : %s", expression, err)
 		return output, err
 	}
 
 	//dump opcode in trace level
 	log.Tracef("%s", runtimeFilter.Disassemble())
 
-	output, err = expr.Run(runtimeFilter, map[string]interface{}{"results": *p.TestData})
+	output, err = expr.Run(runtimeFilter, env)
 	if err != nil {
 		log.Warningf("running : %s", expression)
 		log.Warningf("runtime error : %s", err)
@@ -251,8 +255,8 @@ func (p *ParserAssert) AutoGenParserAssert() string {
 						continue
 					}
 					base := fmt.Sprintf(`results["%s"]["%s"][%d].Evt.Unmarshaled["%s"]`, stage, parser, pidx, ekey)
-					for _, line := range p.buildUnmarshaledAssert("", eval) {
-						ret += base + line
+					for _, line := range p.buildUnmarshaledAssert(base, eval) {
+						ret += line
 					}
 				}
 				ret += fmt.Sprintf(`results["%s"]["%s"][%d].Evt.Whitelisted == %t`+"\n", stage, parser, pidx, result.Evt.Whitelisted)
@@ -284,7 +288,8 @@ func (p *ParserAssert) buildUnmarshaledAssert(ekey string, eval interface{}) []s
 	case int:
 		ret = append(ret, fmt.Sprintf(`%s == %d`+"\n", ekey, val))
 	case float64:
-		ret = append(ret, fmt.Sprintf(`%s == %f`+"\n", ekey, val))
+		ret = append(ret, fmt.Sprintf(`FloatApproxEqual(%s, %f)`+"\n",
+			ekey, val))
 	default:
 		log.Warningf("unknown type '%T' for key '%s'", val, ekey)
 	}