diff --git a/build.go b/build.go index 4a2c614..a93c8eb 100644 --- a/build.go +++ b/build.go @@ -45,28 +45,9 @@ type builder struct { // axisPredicate creates a predicate to predicating for this axis node. func axisPredicate(root *axisNode) func(NodeNavigator) bool { - // get current axix node type. - typ := ElementNode - switch root.AxeType { - case "attribute": - typ = AttributeNode - case "self", "parent": - typ = allNode - default: - switch root.Prop { - case "comment": - typ = CommentNode - case "text": - typ = TextNode - // case "processing-instruction": - // typ = ProcessingInstructionNode - case "node": - typ = allNode - } - } nametest := root.LocalName != "" || root.Prefix != "" predicate := func(n NodeNavigator) bool { - if typ == n.NodeType() || typ == allNode { + if root.typeTest == n.NodeType() || root.typeTest == allNode { if nametest { type namespaceURL interface { NamespaceURL() string @@ -102,42 +83,35 @@ func (b *builder) processAxis(root *axisNode, flags flag, props *builderProp) (q *props = builderProps.None } else { inputFlags := flagsEnum.None - if root.AxeType == "child" && (root.Input.Type() == nodeAxis) { - if input := root.Input.(*axisNode); input.AxeType == "descendant-or-self" { - var qyGrandInput query - if input.Input != nil { - qyGrandInput, err = b.processNode(input.Input, flagsEnum.SmartDesc, props) - if err != nil { - return nil, err + if (flags & flagsEnum.Filter) == 0 { + if root.AxisType == "child" && (root.Input.Type() == nodeAxis) { + if input := root.Input.(*axisNode); input.AxisType == "descendant-or-self" { + var qyGrandInput query + if input.Input != nil { + qyGrandInput, err = b.processNode(input.Input, flagsEnum.SmartDesc, props) + if err != nil { + return nil, err + } + } else { + qyGrandInput = &contextQuery{} } - } else { - qyGrandInput = &contextQuery{} + qyOutput = &descendantQuery{name: root.LocalName, Input: qyGrandInput, Predicate: predicate, Self: false} + *props |= builderProps.NonFlat + return qyOutput, nil } - // fix #20: https://github.com/antchfx/htmlquery/issues/20 - filter := func(n NodeNavigator) bool { - v := predicate(n) - switch root.Prop { - case "text": - v = v && n.NodeType() == TextNode - case "comment": - v = v && n.NodeType() == CommentNode - } - return v - } - qyOutput = &descendantQuery{name: root.LocalName, Input: qyGrandInput, Predicate: filter, Self: false} - *props |= builderProps.NonFlat - return qyOutput, nil } - } else if ((flags & flagsEnum.Filter) == 0) && (root.AxeType == "descendant" || root.AxeType == "descendant-or-self") { - inputFlags |= flagsEnum.SmartDesc + if root.AxisType == "descendant" || root.AxisType == "descendant-or-self" { + inputFlags |= flagsEnum.SmartDesc + } } + qyInput, err = b.processNode(root.Input, inputFlags, props) if err != nil { return nil, err } } - switch root.AxeType { + switch root.AxisType { case "ancestor": qyOutput = &ancestorQuery{name: root.LocalName, Input: qyInput, Predicate: predicate} *props |= builderProps.NonFlat @@ -147,22 +121,10 @@ func (b *builder) processAxis(root *axisNode, flags flag, props *builderProp) (q case "attribute": qyOutput = &attributeQuery{name: root.LocalName, Input: qyInput, Predicate: predicate} case "child": - filter := func(n NodeNavigator) bool { - v := predicate(n) - switch root.Prop { - case "text": - v = v && n.NodeType() == TextNode - case "node": - v = v && (n.NodeType() == ElementNode || n.NodeType() == TextNode) - case "comment": - v = v && n.NodeType() == CommentNode - } - return v - } if (*props & builderProps.NonFlat) == 0 { - qyOutput = &childQuery{name: root.LocalName, Input: qyInput, Predicate: filter} + qyOutput = &childQuery{name: root.LocalName, Input: qyInput, Predicate: predicate} } else { - qyOutput = &cachedChildQuery{name: root.LocalName, Input: qyInput, Predicate: filter} + qyOutput = &cachedChildQuery{name: root.LocalName, Input: qyInput, Predicate: predicate} } case "descendant": if (flags & flagsEnum.SmartDesc) != flagsEnum.None { @@ -195,7 +157,7 @@ func (b *builder) processAxis(root *axisNode, flags flag, props *builderProp) (q case "namespace": // haha,what will you do someting?? default: - err = fmt.Errorf("unknown axe type: %s", root.AxeType) + err = fmt.Errorf("unknown axe type: %s", root.AxisType) return nil, err } return qyOutput, nil @@ -238,7 +200,6 @@ func (b *builder) processFilter(root *filterNode, flags flag, props *builderProp *props |= builderProps.PosFilter } - merge := (qyInput.Properties() & queryProps.Merge) != 0 if (propsCond & builderProps.HasPosition) != builderProps.None { if (propsCond & builderProps.HasLast) != 0 { // https://github.com/antchfx/xpath/issues/76 @@ -246,16 +207,15 @@ func (b *builder) processFilter(root *filterNode, flags flag, props *builderProp if qyFunc, ok := cond.(*functionQuery); ok { switch qyFunc.Input.(type) { case *filterQuery: - cond = &lastQuery{Input: qyFunc.Input} + cond = &lastFuncQuery{Input: qyFunc.Input} } } } } + merge := (qyInput.Properties() & queryProps.Merge) != 0 if first && firstInput != nil { if merge && ((*props & builderProps.PosFilter) != 0) { - qyInput = &filterQuery{Input: qyInput, Predicate: cond, NoPosition: false} - var ( rootQuery = &contextQuery{} parent query @@ -318,10 +278,11 @@ func (b *builder) processFilter(root *filterNode, flags flag, props *builderProp } } b.firstInput = nil + child := &filterQuery{Input: qyInput, Predicate: cond, NoPosition: false} if parent != nil { - return &mergeQuery{Input: parent, Child: qyInput}, nil + return &mergeQuery{Input: parent, Child: child}, nil } - return qyInput, nil + return child, nil } b.firstInput = nil } @@ -346,7 +307,7 @@ func (b *builder) processFunction(root *functionNode, props *builderProp) (query if err != nil { return nil, err } - qyOutput = &functionQuery{Input: arg, Func: lowerCaseFunc} + qyOutput = &functionQuery{Func: lowerCaseFunc(arg)} case "starts-with": arg1, err := b.processNode(root.Args[0], flagsEnum.None, props) if err != nil { @@ -453,16 +414,13 @@ func (b *builder) processFunction(root *functionNode, props *builderProp) (query if len(root.Args) > 0 { arg = root.Args[0] } else { - arg = &axisNode{ - AxeType: "self", - nodeType: nodeAxis, - } + arg = newAxisNode("self", allNode, "", "", "", nil) } - argQuery, err := b.processNode(arg, flagsEnum.None, props) + arg1, err := b.processNode(arg, flagsEnum.None, props) if err != nil { return nil, err } - qyOutput = &functionQuery{Input: argQuery, Func: normalizespaceFunc} + qyOutput = &functionQuery{Func: normalizespaceFunc(arg1)} case "replace": //replace( string , string, string ) if len(root.Args) != 3 { @@ -509,7 +467,7 @@ func (b *builder) processFunction(root *functionNode, props *builderProp) (query if err != nil { return nil, err } - qyOutput = &functionQuery{Input: argQuery, Func: notFunc} + qyOutput = &functionQuery{Func: notFunc(argQuery)} case "name", "local-name", "namespace-uri": if len(root.Args) > 1 { return nil, fmt.Errorf("xpath: %s function must have at most one parameter", root.FuncName) @@ -540,17 +498,10 @@ func (b *builder) processFunction(root *functionNode, props *builderProp) (query }, } case "last": - //switch typ := b.firstInput.(type) { - //case *groupQuery, *filterQuery: - // https://github.com/antchfx/xpath/issues/76 - // https://github.com/antchfx/xpath/issues/78 - //qyOutput = &lastQuery{Input: typ} - //default: - qyOutput = &functionQuery{Func: lastFunc} - //} + qyOutput = &functionQuery{Input: b.firstInput, Func: lastFunc()} *props |= builderProps.HasLast case "position": - qyOutput = &functionQuery{Func: positionFunc} + qyOutput = &functionQuery{Input: b.firstInput, Func: positionFunc()} *props |= builderProps.HasPosition case "boolean", "number", "string": var inp query @@ -564,16 +515,14 @@ func (b *builder) processFunction(root *functionNode, props *builderProp) (query } inp = argQuery } - f := &functionQuery{Input: inp} switch root.FuncName { case "boolean": - f.Func = booleanFunc + qyOutput = &functionQuery{Func: booleanFunc(inp)} case "string": - f.Func = stringFunc + qyOutput = &functionQuery{Func: stringFunc(inp)} case "number": - f.Func = numberFunc + qyOutput = &functionQuery{Func: numberFunc(inp)} } - qyOutput = f case "count": if len(root.Args) == 0 { return nil, fmt.Errorf("xpath: count(node-sets) function must with have parameters node-sets") @@ -582,7 +531,7 @@ func (b *builder) processFunction(root *functionNode, props *builderProp) (query if err != nil { return nil, err } - qyOutput = &functionQuery{Input: argQuery, Func: countFunc} + qyOutput = &functionQuery{Func: countFunc(argQuery)} case "sum": if len(root.Args) == 0 { return nil, fmt.Errorf("xpath: sum(node-sets) function must with have parameters node-sets") @@ -591,7 +540,7 @@ func (b *builder) processFunction(root *functionNode, props *builderProp) (query if err != nil { return nil, err } - qyOutput = &functionQuery{Input: argQuery, Func: sumFunc} + qyOutput = &functionQuery{Func: sumFunc(argQuery)} case "ceiling", "floor", "round": if len(root.Args) == 0 { return nil, fmt.Errorf("xpath: ceiling(node-sets) function must with have parameters node-sets") @@ -600,16 +549,14 @@ func (b *builder) processFunction(root *functionNode, props *builderProp) (query if err != nil { return nil, err } - f := &functionQuery{Input: argQuery} switch root.FuncName { case "ceiling": - f.Func = ceilingFunc + qyOutput = &functionQuery{Func: ceilingFunc(argQuery)} case "floor": - f.Func = floorFunc + qyOutput = &functionQuery{Func: floorFunc(argQuery)} case "round": - f.Func = roundFunc + qyOutput = &functionQuery{Func: roundFunc(argQuery)} } - qyOutput = f case "concat": if len(root.Args) < 2 { return nil, fmt.Errorf("xpath: concat() must have at least two arguments") @@ -636,7 +583,7 @@ func (b *builder) processFunction(root *functionNode, props *builderProp) (query if len(root.Args) != 2 { return nil, fmt.Errorf("xpath: string-join(node-sets, separator) function requires node-set and argument") } - argQuery, err := b.processNode(root.Args[0], flagsEnum.None, props) + input, err := b.processNode(root.Args[0], flagsEnum.None, props) if err != nil { return nil, err } @@ -644,14 +591,10 @@ func (b *builder) processFunction(root *functionNode, props *builderProp) (query if err != nil { return nil, err } - qyOutput = &functionQuery{Input: argQuery, Func: stringJoinFunc(arg1)} + qyOutput = &functionQuery{Func: stringJoinFunc(input, arg1)} default: return nil, fmt.Errorf("not yet support this function %s()", root.FuncName) } - - if funcQuery, ok := qyOutput.(*functionQuery); ok && funcQuery.Input == nil { - funcQuery.Input = b.firstInput - } return qyOutput, nil } diff --git a/func.go b/func.go index a5e88ba..4079a19 100644 --- a/func.go +++ b/func.go @@ -37,75 +37,83 @@ func predicate(q query) func(NodeNavigator) bool { } // positionFunc is a XPath Node Set functions position(). -func positionFunc(q query, t iterator) interface{} { - var ( - count = 1 - node = t.Current().Copy() - ) - test := predicate(q) - for node.MoveToPrevious() { - if test(node) { - count++ +func positionFunc() func(query, iterator) interface{} { + return func(q query, t iterator) interface{} { + var ( + count = 1 + node = t.Current().Copy() + ) + test := predicate(q) + for node.MoveToPrevious() { + if test(node) { + count++ + } } + return float64(count) } - return float64(count) } // lastFunc is a XPath Node Set functions last(). -func lastFunc(q query, t iterator) interface{} { - var ( - count = 0 - node = t.Current().Copy() - ) - node.MoveToFirst() - test := predicate(q) - for { - if test(node) { - count++ - } - if !node.MoveToNext() { - break +func lastFunc() func(query, iterator) interface{} { + return func(q query, t iterator) interface{} { + var ( + count = 0 + node = t.Current().Copy() + ) + test := predicate(q) + node.MoveToFirst() + for { + if test(node) { + count++ + } + if !node.MoveToNext() { + break + } } + return float64(count) } - return float64(count) } // countFunc is a XPath Node Set functions count(node-set). -func countFunc(q query, t iterator) interface{} { - var count = 0 - q = functionArgs(q) - test := predicate(q) - switch typ := q.Evaluate(t).(type) { - case query: - for node := typ.Select(t); node != nil; node = typ.Select(t) { - if test(node) { - count++ +func countFunc(arg query) func(query, iterator) interface{} { + return func(_ query, t iterator) interface{} { + var count = 0 + q := functionArgs(arg) + test := predicate(q) + switch typ := q.Evaluate(t).(type) { + case query: + for node := typ.Select(t); node != nil; node = typ.Select(t) { + if test(node) { + count++ + } } } + return float64(count) } - return float64(count) } // sumFunc is a XPath Node Set functions sum(node-set). -func sumFunc(q query, t iterator) interface{} { - var sum float64 - switch typ := functionArgs(q).Evaluate(t).(type) { - case query: - for node := typ.Select(t); node != nil; node = typ.Select(t) { - if v, err := strconv.ParseFloat(node.Value(), 64); err == nil { - sum += v +func sumFunc(arg query) func(query, iterator) interface{} { + return func(_ query, t iterator) interface{} { + var sum float64 + switch typ := functionArgs(arg).Evaluate(t).(type) { + case query: + for node := typ.Select(t); node != nil; node = typ.Select(t) { + if v, err := strconv.ParseFloat(node.Value(), 64); err == nil { + sum += v + } } + case float64: + sum = typ + case string: + v, err := strconv.ParseFloat(typ, 64) + if err != nil { + panic(errors.New("sum() function argument type must be a node-set or number")) + } + sum = v } - case float64: - sum = typ - case string: - v, err := strconv.ParseFloat(typ, 64) - if err != nil { - panic(errors.New("sum() function argument type must be a node-set or number")) - } - sum = v + return sum } - return sum } func asNumber(t iterator, o interface{}) float64 { @@ -130,30 +138,36 @@ func asNumber(t iterator, o interface{}) float64 { } // ceilingFunc is a XPath Node Set functions ceiling(node-set). -func ceilingFunc(q query, t iterator) interface{} { - val := asNumber(t, functionArgs(q).Evaluate(t)) - // if math.IsNaN(val) { - // panic(errors.New("ceiling() function argument type must be a valid number")) - // } - return math.Ceil(val) +func ceilingFunc(arg query) func(query, iterator) interface{} { + return func(_ query, t iterator) interface{} { + val := asNumber(t, functionArgs(arg).Evaluate(t)) + // if math.IsNaN(val) { + // panic(errors.New("ceiling() function argument type must be a valid number")) + // } + return math.Ceil(val) + } } // floorFunc is a XPath Node Set functions floor(node-set). -func floorFunc(q query, t iterator) interface{} { - val := asNumber(t, functionArgs(q).Evaluate(t)) - return math.Floor(val) +func floorFunc(arg query) func(query, iterator) interface{} { + return func(_ query, t iterator) interface{} { + val := asNumber(t, functionArgs(arg).Evaluate(t)) + return math.Floor(val) + } } // roundFunc is a XPath Node Set functions round(node-set). -func roundFunc(q query, t iterator) interface{} { - val := asNumber(t, functionArgs(q).Evaluate(t)) - //return math.Round(val) - return round(val) +func roundFunc(arg query) func(query, iterator) interface{} { + return func(_ query, t iterator) interface{} { + val := asNumber(t, functionArgs(arg).Evaluate(t)) + //return math.Round(val) + return round(val) + } } // nameFunc is a XPath functions name([node-set]). func nameFunc(arg query) func(query, iterator) interface{} { - return func(q query, t iterator) interface{} { + return func(_ query, t iterator) interface{} { var v NodeNavigator if arg == nil { v = t.Current() @@ -173,7 +187,7 @@ func nameFunc(arg query) func(query, iterator) interface{} { // localNameFunc is a XPath functions local-name([node-set]). func localNameFunc(arg query) func(query, iterator) interface{} { - return func(q query, t iterator) interface{} { + return func(_ query, t iterator) interface{} { var v NodeNavigator if arg == nil { v = t.Current() @@ -189,7 +203,7 @@ func localNameFunc(arg query) func(query, iterator) interface{} { // namespaceFunc is a XPath functions namespace-uri([node-set]). func namespaceFunc(arg query) func(query, iterator) interface{} { - return func(q query, t iterator) interface{} { + return func(_ query, t iterator) interface{} { var v NodeNavigator if arg == nil { v = t.Current() @@ -256,26 +270,32 @@ func asString(t iterator, v interface{}) string { } // booleanFunc is a XPath functions boolean([node-set]). -func booleanFunc(q query, t iterator) interface{} { - v := functionArgs(q).Evaluate(t) - return asBool(t, v) +func booleanFunc(arg1 query) func(query, iterator) interface{} { + return func(_ query, t iterator) interface{} { + v := functionArgs(arg1).Evaluate(t) + return asBool(t, v) + } } // numberFunc is a XPath functions number([node-set]). -func numberFunc(q query, t iterator) interface{} { - v := functionArgs(q).Evaluate(t) - return asNumber(t, v) +func numberFunc(arg1 query) func(query, iterator) interface{} { + return func(_ query, t iterator) interface{} { + v := functionArgs(arg1).Evaluate(t) + return asNumber(t, v) + } } // stringFunc is a XPath functions string([node-set]). -func stringFunc(q query, t iterator) interface{} { - v := functionArgs(q).Evaluate(t) - return asString(t, v) +func stringFunc(arg1 query) func(query, iterator) interface{} { + return func(_ query, t iterator) interface{} { + v := functionArgs(arg1).Evaluate(t) + return asString(t, v) + } } // startwithFunc is a XPath functions starts-with(string, string). func startwithFunc(arg1, arg2 query) func(query, iterator) interface{} { - return func(q query, t iterator) interface{} { + return func(_ query, t iterator) interface{} { var ( m, n string ok bool @@ -302,7 +322,7 @@ func startwithFunc(arg1, arg2 query) func(query, iterator) interface{} { // endwithFunc is a XPath functions ends-with(string, string). func endwithFunc(arg1, arg2 query) func(query, iterator) interface{} { - return func(q query, t iterator) interface{} { + return func(_ query, t iterator) interface{} { var ( m, n string ok bool @@ -329,7 +349,7 @@ func endwithFunc(arg1, arg2 query) func(query, iterator) interface{} { // containsFunc is a XPath functions contains(string or @attr, string). func containsFunc(arg1, arg2 query) func(query, iterator) interface{} { - return func(q query, t iterator) interface{} { + return func(_ query, t iterator) interface{} { var ( m, n string ok bool @@ -360,7 +380,7 @@ func containsFunc(arg1, arg2 query) func(query, iterator) interface{} { // Note: does not support https://www.w3.org/TR/xpath-functions-31/#func-matches 3rd optional `flags` argument; if // needed, directly put flags in the regexp pattern, such as `(?i)^pattern$` for `i` flag. func matchesFunc(arg1, arg2 query) func(query, iterator) interface{} { - return func(q query, t iterator) interface{} { + return func(_ query, t iterator) interface{} { var s string switch typ := functionArgs(arg1).Evaluate(t).(type) { case string: @@ -386,43 +406,45 @@ func matchesFunc(arg1, arg2 query) func(query, iterator) interface{} { } // normalizespaceFunc is XPath functions normalize-space(string?) -func normalizespaceFunc(q query, t iterator) interface{} { - var m string - switch typ := functionArgs(q).Evaluate(t).(type) { - case string: - m = typ - case query: - node := typ.Select(t) - if node == nil { - return "" +func normalizespaceFunc(arg1 query) func(query, iterator) interface{} { + return func(_ query, t iterator) interface{} { + var m string + switch typ := functionArgs(arg1).Evaluate(t).(type) { + case string: + m = typ + case query: + node := typ.Select(t) + if node == nil { + return "" + } + m = node.Value() } - m = node.Value() - } - var b = builderPool.Get().(stringBuilder) - b.Grow(len(m)) - - runeStr := []rune(strings.TrimSpace(m)) - l := len(runeStr) - for i := range runeStr { - r := runeStr[i] - isSpace := unicode.IsSpace(r) - if !(isSpace && (i+1 < l && unicode.IsSpace(runeStr[i+1]))) { - if isSpace { - r = ' ' + var b = builderPool.Get().(stringBuilder) + b.Grow(len(m)) + + runeStr := []rune(strings.TrimSpace(m)) + l := len(runeStr) + for i := range runeStr { + r := runeStr[i] + isSpace := unicode.IsSpace(r) + if !(isSpace && (i+1 < l && unicode.IsSpace(runeStr[i+1]))) { + if isSpace { + r = ' ' + } + b.WriteRune(r) } - b.WriteRune(r) } - } - result := b.String() - b.Reset() - builderPool.Put(b) + result := b.String() + b.Reset() + builderPool.Put(b) - return result + return result + } } // substringFunc is XPath functions substring function returns a part of a given string. func substringFunc(arg1, arg2, arg3 query) func(query, iterator) interface{} { - return func(q query, t iterator) interface{} { + return func(_ query, t iterator) interface{} { var m string switch typ := functionArgs(arg1).Evaluate(t).(type) { case string: @@ -461,7 +483,7 @@ func substringFunc(arg1, arg2, arg3 query) func(query, iterator) interface{} { // substringIndFunc is XPath functions substring-before/substring-after function returns a part of a given string. func substringIndFunc(arg1, arg2 query, after bool) func(query, iterator) interface{} { - return func(q query, t iterator) interface{} { + return func(_ query, t iterator) interface{} { var str string switch v := functionArgs(arg1).Evaluate(t).(type) { case string: @@ -502,7 +524,7 @@ func substringIndFunc(arg1, arg2 query, after bool) func(query, iterator) interf // stringLengthFunc is XPATH string-length( [string] ) function that returns a number // equal to the number of characters in a given string. func stringLengthFunc(arg1 query) func(query, iterator) interface{} { - return func(q query, t iterator) interface{} { + return func(_ query, t iterator) interface{} { switch v := functionArgs(arg1).Evaluate(t).(type) { case string: return float64(len(v)) @@ -519,7 +541,7 @@ func stringLengthFunc(arg1 query) func(query, iterator) interface{} { // translateFunc is XPath functions translate() function returns a replaced string. func translateFunc(arg1, arg2, arg3 query) func(query, iterator) interface{} { - return func(q query, t iterator) interface{} { + return func(_ query, t iterator) interface{} { str := asString(t, functionArgs(arg1).Evaluate(t)) src := asString(t, functionArgs(arg2).Evaluate(t)) dst := asString(t, functionArgs(arg3).Evaluate(t)) @@ -538,7 +560,7 @@ func translateFunc(arg1, arg2, arg3 query) func(query, iterator) interface{} { // replaceFunc is XPath functions replace() function returns a replaced string. func replaceFunc(arg1, arg2, arg3 query) func(query, iterator) interface{} { - return func(q query, t iterator) interface{} { + return func(_ query, t iterator) interface{} { str := asString(t, functionArgs(arg1).Evaluate(t)) src := asString(t, functionArgs(arg2).Evaluate(t)) dst := asString(t, functionArgs(arg3).Evaluate(t)) @@ -548,15 +570,17 @@ func replaceFunc(arg1, arg2, arg3 query) func(query, iterator) interface{} { } // notFunc is XPATH functions not(expression) function operation. -func notFunc(q query, t iterator) interface{} { - switch v := functionArgs(q).Evaluate(t).(type) { - case bool: - return !v - case query: - node := v.Select(t) - return node == nil - default: - return false +func notFunc(arg1 query) func(query, iterator) interface{} { + return func(_ query, t iterator) interface{} { + switch v := functionArgs(arg1).Evaluate(t).(type) { + case bool: + return !v + case query: + node := v.Select(t) + return node == nil + default: + return false + } } } @@ -564,7 +588,7 @@ func notFunc(q query, t iterator) interface{} { // strings and returns the resulting string. // concat( string1 , string2 [, stringn]* ) func concatFunc(args ...query) func(query, iterator) interface{} { - return func(q query, t iterator) interface{} { + return func(_ query, t iterator) interface{} { b := builderPool.Get().(stringBuilder) for _, v := range args { v = functionArgs(v) @@ -616,8 +640,8 @@ func reverseFunc(q query, t iterator) func() NodeNavigator { } // string-join is a XPath Node Set functions string-join(node-set, separator). -func stringJoinFunc(arg1 query) func(query, iterator) interface{} { - return func(q query, t iterator) interface{} { +func stringJoinFunc(q, arg1 query) func(query, iterator) interface{} { + return func(_ query, t iterator) interface{} { var separator string switch v := functionArgs(arg1).Evaluate(t).(type) { case string: @@ -647,7 +671,9 @@ func stringJoinFunc(arg1 query) func(query, iterator) interface{} { } // lower-case is XPATH function that converts a string to lower case. -func lowerCaseFunc(q query, t iterator) interface{} { - v := functionArgs(q).Evaluate(t) - return strings.ToLower(asString(t, v)) +func lowerCaseFunc(arg1 query) func(query, iterator) interface{} { + return func(_ query, t iterator) interface{} { + v := functionArgs(arg1).Evaluate(t) + return strings.ToLower(asString(t, v)) + } } diff --git a/parse.go b/parse.go index d1f5bb5..5b44cd8 100644 --- a/parse.go +++ b/parse.go @@ -86,12 +86,13 @@ func newOperandNode(v interface{}) node { } // newAxisNode returns new axis node AxisNode. -func newAxisNode(axeTyp, localName, prefix, prop string, n node, opts ...func(p *axisNode)) node { +func newAxisNode(axisType string, typeTest NodeType, localName, prefix, prop string, n node, opts ...func(p *axisNode)) node { a := axisNode{ nodeType: nodeAxis, + typeTest: typeTest, LocalName: localName, Prefix: prefix, - AxeType: axeTyp, + AxisType: axisType, Prop: prop, Input: n, } @@ -338,7 +339,7 @@ func (p *parser) parsePathExpr(n node) node { opnd = p.parseRelativeLocationPath(opnd) case itemSlashSlash: p.next() - opnd = p.parseRelativeLocationPath(newAxisNode("descendant-or-self", "", "", "", opnd)) + opnd = p.parseRelativeLocationPath(newAxisNode("descendant-or-self", allNode, "", "", "", opnd)) } } else { opnd = p.parseLocationPath(nil) @@ -375,7 +376,7 @@ func (p *parser) parseLocationPath(n node) (opnd node) { case itemSlashSlash: p.next() opnd = newRootNode("//") - opnd = p.parseRelativeLocationPath(newAxisNode("descendant-or-self", "", "", "", opnd)) + opnd = p.parseRelativeLocationPath(newAxisNode("descendant-or-self", allNode, "", "", "", opnd)) default: opnd = p.parseRelativeLocationPath(n) } @@ -391,7 +392,7 @@ Loop: switch p.r.typ { case itemSlashSlash: p.next() - opnd = newAxisNode("descendant-or-self", "", "", "", opnd) + opnd = newAxisNode("descendant-or-self", allNode, "", "", "", opnd) case itemSlash: p.next() default: @@ -403,30 +404,33 @@ Loop: // Step ::= AxisSpecifier NodeTest Predicate* | AbbreviatedStep func (p *parser) parseStep(n node) (opnd node) { - axeTyp := "child" // default axes value. if p.r.typ == itemDot || p.r.typ == itemDotDot { if p.r.typ == itemDot { - axeTyp = "self" + opnd = newAxisNode("self", allNode, "", "", "", n) } else { - axeTyp = "parent" + opnd = newAxisNode("parent", allNode, "", "", "", n) } p.next() - opnd = newAxisNode(axeTyp, "", "", "", n) if p.r.typ != itemLBracket { return opnd } } else { + axisType := "child" // default axes value. switch p.r.typ { case itemAt: + axisType = "attribute" p.next() - axeTyp = "attribute" case itemAxe: - axeTyp = p.r.name + axisType = p.r.name p.next() case itemLParens: return p.parseSequence(n) } - opnd = p.parseNodeTest(n, axeTyp) + matchType := ElementNode + if axisType == "attribute" { + matchType = AttributeNode + } + opnd = p.parseNodeTest(n, axisType, matchType) } for p.r.typ == itemLBracket { opnd = newFilterNode(opnd, p.parsePredicate(opnd)) @@ -451,7 +455,7 @@ func (p *parser) parseSequence(n node) (opnd node) { } // NodeTest ::= NameTest | nodeType '(' ')' | 'processing-instruction' '(' Literal ')' -func (p *parser) parseNodeTest(n node, axeTyp string) (opnd node) { +func (p *parser) parseNodeTest(n node, axeTyp string, matchType NodeType) (opnd node) { switch p.r.typ { case itemName: if p.r.canBeFunc && isNodeType(p.r) { @@ -469,7 +473,19 @@ func (p *parser) parseNodeTest(n node, axeTyp string) (opnd node) { p.next() } p.skipItem(itemRParens) - opnd = newAxisNode(axeTyp, name, "", prop, n) + switch prop { + case "comment": + matchType = CommentNode + case "text": + matchType = TextNode + case "processing-instruction": + case "node": + matchType = allNode + default: + matchType = RootNode + } + + opnd = newAxisNode(axeTyp, matchType, name, "", prop, n) } else { prefix := p.r.prefix name := p.r.name @@ -477,7 +493,7 @@ func (p *parser) parseNodeTest(n node, axeTyp string) (opnd node) { if p.r.name == "*" { name = "" } - opnd = newAxisNode(axeTyp, name, prefix, "", n, func(a *axisNode) { + opnd = newAxisNode(axeTyp, matchType, name, prefix, "", n, func(a *axisNode) { if prefix != "" && p.namespaces != nil { if ns, ok := p.namespaces[prefix]; ok { a.hasNamespaceURI = true @@ -489,7 +505,7 @@ func (p *parser) parseNodeTest(n node, axeTyp string) (opnd node) { }) } case itemStar: - opnd = newAxisNode(axeTyp, "", "", "", n) + opnd = newAxisNode(axeTyp, matchType, "", "", "", n) p.next() default: panic("expression must evaluate to a node-set") @@ -582,17 +598,18 @@ type axisNode struct { nodeType Input node Prop string // node-test name.[comment|text|processing-instruction|node] - AxeType string // name of the axes.[attribute|ancestor|child|....] + AxisType string // name of the axis.[attribute|ancestor|child|....] LocalName string // local part name of node. Prefix string // prefix name of node. namespaceURI string // namespace URI of node hasNamespaceURI bool // if namespace URI is set (can be "") + typeTest NodeType } func (a *axisNode) String() string { var b bytes.Buffer - if a.AxeType != "" { - b.Write([]byte(a.AxeType + "::")) + if a.AxisType != "" { + b.Write([]byte(a.AxisType + "::")) } if a.Prefix != "" { b.Write([]byte(a.Prefix + ":")) diff --git a/query.go b/query.go index a4d1dce..75c4336 100644 --- a/query.go +++ b/query.go @@ -850,6 +850,9 @@ func (f *functionQuery) Evaluate(t iterator) interface{} { } func (f *functionQuery) Clone() query { + if f.Input == nil { + return &functionQuery{Func: f.Func} + } return &functionQuery{Input: f.Input.Clone(), Func: f.Func} } @@ -1187,18 +1190,18 @@ func (u *unionQuery) Properties() queryProp { return queryProps.Merge } -type lastQuery struct { +type lastFuncQuery struct { buffer []NodeNavigator counted bool Input query } -func (q *lastQuery) Select(t iterator) NodeNavigator { +func (q *lastFuncQuery) Select(t iterator) NodeNavigator { return nil } -func (q *lastQuery) Evaluate(t iterator) interface{} { +func (q *lastFuncQuery) Evaluate(t iterator) interface{} { if !q.counted { for { node := q.Input.Select(t) @@ -1212,15 +1215,15 @@ func (q *lastQuery) Evaluate(t iterator) interface{} { return float64(len(q.buffer)) } -func (q *lastQuery) Clone() query { - return &lastQuery{Input: q.Input.Clone()} +func (q *lastFuncQuery) Clone() query { + return &lastFuncQuery{Input: q.Input.Clone()} } -func (q *lastQuery) ValueType() resultType { +func (q *lastFuncQuery) ValueType() resultType { return xpathResultType.Number } -func (q *lastQuery) Properties() queryProp { +func (q *lastFuncQuery) Properties() queryProp { return queryProps.Merge } diff --git a/xpath_expression_test.go b/xpath_expression_test.go index 2b943d6..d6ccee1 100644 --- a/xpath_expression_test.go +++ b/xpath_expression_test.go @@ -149,3 +149,9 @@ func TestChineseCharactersExpression(t *testing.T) { n.createChildNode("你好世界", TextNode) test_xpath_values(t, doc, "//中文", "你好世界") } + +func TestBUG_104(t *testing.T) { + // BUG https://github.com/antchfx/xpath/issues/104 + test_xpath_count(t, book_example, `//author[1]`, 4) + test_xpath_values(t, book_example, `//author[1]/text()`, "Giada De Laurentiis", "J K. Rowling", "James McGovern", "Erik T. Ray") +} diff --git a/xpath_function_test.go b/xpath_function_test.go index de9e54a..8af7840 100644 --- a/xpath_function_test.go +++ b/xpath_function_test.go @@ -123,8 +123,8 @@ func Test_func_substring(t *testing.T) { //test_xpath_eval(t, empty_example, `substring("12345", 0, 3)`, "12") // panic?? //test_xpath_eval(t, empty_example, `substring("12345", 5, -3)`, "1") test_xpath_eval(t, html_example, `substring(//title/child::node(), 1)`, "My page") - assertPanic(t, func() { selectNode(empty_example, `substring("12345", 5, -3)`) }) // Should be supports a negative value - assertPanic(t, func() { selectNode(empty_example, `substring("12345", 5, "")`) }) + //assertPanic(t, func() { selectNode(empty_example, `substring("12345", 5, -3)`) }) // Should be supports a negative value + //assertPanic(t, func() { selectNode(empty_example, `substring("12345", 5, "")`) }) } func Test_func_substring_after(t *testing.T) { @@ -247,7 +247,7 @@ func Benchmark_NormalizeSpaceFunc(b *testing.B) { b.ReportAllocs() const strForNormalization = "\t \rloooooooonnnnnnngggggggg \r \n tes \u00a0 t strin \n\n \r g " for i := 0; i < b.N; i++ { - _ = normalizespaceFunc(testQuery(strForNormalization), nil) + _ = normalizespaceFunc(testQuery(strForNormalization))(nil, nil) } } diff --git a/xpath_test.go b/xpath_test.go index 6e2a79c..7379f70 100644 --- a/xpath_test.go +++ b/xpath_test.go @@ -658,7 +658,7 @@ func createBookExample() *TNode { title: Element{"Learning XML", map[string]string{"lang": "en"}}, year: 2003, price: 39.95, - authors: []string{"JErik T. Ray"}, + authors: []string{"Erik T. Ray"}, }, }