diff --git a/lib.go b/lib.go index f592fbe..0710a7a 100644 --- a/lib.go +++ b/lib.go @@ -16,6 +16,7 @@ import ( "bytes" "container/list" "encoding/xml" + "errors" "fmt" "io" "io/ioutil" @@ -398,7 +399,7 @@ func boolPtr(b bool) *bool { return &b } // intPtr returns a pointer to a int with the given value. func intPtr(i int) *int { return &i } -// float64Ptr returns a pofloat64er to a float64 with the given value. +// float64Ptr returns a pointer to a float64 with the given value. func float64Ptr(f float64) *float64 { return &f } // stringPtr returns a pointer to a string with the given value. @@ -412,6 +413,66 @@ func defaultTrue(b *bool) bool { return *b } +// MarshalXMLMarshalXML convert the boolean data type to literal values 0 or 1 +// on serialization. +func (avb attrValBool) MarshalXML(e *xml.Encoder, start xml.StartElement) error { + attr := xml.Attr{ + Name: xml.Name{ + Space: start.Name.Space, + Local: "val", + }, + Value: "0", + } + if avb.Val != nil { + if *avb.Val { + attr.Value = "1" + } else { + attr.Value = "0" + } + } + start.Attr = []xml.Attr{attr} + e.EncodeToken(start) + e.EncodeToken(start.End()) + return nil +} + +// UnmarshalXML convert the literal values true, false, 1, 0 of the XML +// attribute to boolean data type on de-serialization. +func (avb *attrValBool) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error { + for { + t, err := d.Token() + if err != nil { + return err + } + found := false + switch t.(type) { + case xml.StartElement: + return errors.New("unexpected child of attrValBool") + case xml.EndElement: + found = true + } + if found { + break + } + } + for _, attr := range start.Attr { + if attr.Name.Local == "val" { + if attr.Value == "" { + val := true + avb.Val = &val + } else { + val, err := strconv.ParseBool(attr.Value) + if err != nil { + return err + } + avb.Val = &val + } + return nil + } + } + return nil +} + // parseFormatSet provides a method to convert format string to []byte and // handle empty string. func parseFormatSet(formatSet string) []byte { diff --git a/lib_test.go b/lib_test.go index 84a52bb..35dd2a0 100644 --- a/lib_test.go +++ b/lib_test.go @@ -5,6 +5,7 @@ import ( "bytes" "encoding/xml" "fmt" + "io" "os" "strconv" "strings" @@ -237,6 +238,36 @@ func TestInStrSlice(t *testing.T) { assert.EqualValues(t, -1, inStrSlice([]string{}, "")) } +func TestBoolValMarshal(t *testing.T) { + bold := true + node := &xlsxFont{B: &attrValBool{Val: &bold}} + data, err := xml.Marshal(node) + assert.NoError(t, err) + assert.Equal(t, ``, string(data)) + + node = &xlsxFont{} + err = xml.Unmarshal(data, node) + assert.NoError(t, err) + assert.NotEqual(t, nil, node) + assert.NotEqual(t, nil, node.B) + assert.NotEqual(t, nil, node.B.Val) + assert.Equal(t, true, *node.B.Val) +} + +func TestBoolValUnmarshalXML(t *testing.T) { + node := xlsxFont{} + assert.NoError(t, xml.Unmarshal([]byte(""), &node)) + assert.Equal(t, true, *node.B.Val) + for content, err := range map[string]string{ + "": "unexpected child of attrValBool", + "": "strconv.ParseBool: parsing \"x\": invalid syntax", + } { + assert.EqualError(t, xml.Unmarshal([]byte(content), &node), err) + } + attr := attrValBool{} + assert.EqualError(t, attr.UnmarshalXML(xml.NewDecoder(strings.NewReader("")), xml.StartElement{}), io.EOF.Error()) +} + func TestBytesReplace(t *testing.T) { s := []byte{0x01} assert.EqualValues(t, s, bytesReplace(s, []byte{}, []byte{}, 0))