diff --git a/lib.go b/lib.go
index f592fbe..0710a7a 100644
--- a/lib.go
+++ b/lib.go
@@ -16,6 +16,7 @@ import (
"bytes"
"container/list"
"encoding/xml"
+ "errors"
"fmt"
"io"
"io/ioutil"
@@ -398,7 +399,7 @@ func boolPtr(b bool) *bool { return &b }
// intPtr returns a pointer to a int with the given value.
func intPtr(i int) *int { return &i }
-// float64Ptr returns a pofloat64er to a float64 with the given value.
+// float64Ptr returns a pointer to a float64 with the given value.
func float64Ptr(f float64) *float64 { return &f }
// stringPtr returns a pointer to a string with the given value.
@@ -412,6 +413,66 @@ func defaultTrue(b *bool) bool {
return *b
}
+// MarshalXMLMarshalXML convert the boolean data type to literal values 0 or 1
+// on serialization.
+func (avb attrValBool) MarshalXML(e *xml.Encoder, start xml.StartElement) error {
+ attr := xml.Attr{
+ Name: xml.Name{
+ Space: start.Name.Space,
+ Local: "val",
+ },
+ Value: "0",
+ }
+ if avb.Val != nil {
+ if *avb.Val {
+ attr.Value = "1"
+ } else {
+ attr.Value = "0"
+ }
+ }
+ start.Attr = []xml.Attr{attr}
+ e.EncodeToken(start)
+ e.EncodeToken(start.End())
+ return nil
+}
+
+// UnmarshalXML convert the literal values true, false, 1, 0 of the XML
+// attribute to boolean data type on de-serialization.
+func (avb *attrValBool) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error {
+ for {
+ t, err := d.Token()
+ if err != nil {
+ return err
+ }
+ found := false
+ switch t.(type) {
+ case xml.StartElement:
+ return errors.New("unexpected child of attrValBool")
+ case xml.EndElement:
+ found = true
+ }
+ if found {
+ break
+ }
+ }
+ for _, attr := range start.Attr {
+ if attr.Name.Local == "val" {
+ if attr.Value == "" {
+ val := true
+ avb.Val = &val
+ } else {
+ val, err := strconv.ParseBool(attr.Value)
+ if err != nil {
+ return err
+ }
+ avb.Val = &val
+ }
+ return nil
+ }
+ }
+ return nil
+}
+
// parseFormatSet provides a method to convert format string to []byte and
// handle empty string.
func parseFormatSet(formatSet string) []byte {
diff --git a/lib_test.go b/lib_test.go
index 84a52bb..35dd2a0 100644
--- a/lib_test.go
+++ b/lib_test.go
@@ -5,6 +5,7 @@ import (
"bytes"
"encoding/xml"
"fmt"
+ "io"
"os"
"strconv"
"strings"
@@ -237,6 +238,36 @@ func TestInStrSlice(t *testing.T) {
assert.EqualValues(t, -1, inStrSlice([]string{}, ""))
}
+func TestBoolValMarshal(t *testing.T) {
+ bold := true
+ node := &xlsxFont{B: &attrValBool{Val: &bold}}
+ data, err := xml.Marshal(node)
+ assert.NoError(t, err)
+ assert.Equal(t, ``, string(data))
+
+ node = &xlsxFont{}
+ err = xml.Unmarshal(data, node)
+ assert.NoError(t, err)
+ assert.NotEqual(t, nil, node)
+ assert.NotEqual(t, nil, node.B)
+ assert.NotEqual(t, nil, node.B.Val)
+ assert.Equal(t, true, *node.B.Val)
+}
+
+func TestBoolValUnmarshalXML(t *testing.T) {
+ node := xlsxFont{}
+ assert.NoError(t, xml.Unmarshal([]byte(""), &node))
+ assert.Equal(t, true, *node.B.Val)
+ for content, err := range map[string]string{
+ "": "unexpected child of attrValBool",
+ "": "strconv.ParseBool: parsing \"x\": invalid syntax",
+ } {
+ assert.EqualError(t, xml.Unmarshal([]byte(content), &node), err)
+ }
+ attr := attrValBool{}
+ assert.EqualError(t, attr.UnmarshalXML(xml.NewDecoder(strings.NewReader("")), xml.StartElement{}), io.EOF.Error())
+}
+
func TestBytesReplace(t *testing.T) {
s := []byte{0x01}
assert.EqualValues(t, s, bytesReplace(s, []byte{}, []byte{}, 0))