feat: implement SHA-512 algorithm to ProtectSheet (#1115)

pull/2/head
Jonham.Chen 3 years ago committed by GitHub
parent 9e64df6a96
commit af5c4d00e8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -980,12 +980,12 @@ func (f *File) getFormatChart(format string, combo []string) (*formatChart, []*f
return formatSet, comboCharts, err return formatSet, comboCharts, err
} }
if _, ok := chartValAxNumFmtFormatCode[comboChart.Type]; !ok { if _, ok := chartValAxNumFmtFormatCode[comboChart.Type]; !ok {
return formatSet, comboCharts, newUnsupportChartType(comboChart.Type) return formatSet, comboCharts, newUnsupportedChartType(comboChart.Type)
} }
comboCharts = append(comboCharts, comboChart) comboCharts = append(comboCharts, comboChart)
} }
if _, ok := chartValAxNumFmtFormatCode[formatSet.Type]; !ok { if _, ok := chartValAxNumFmtFormatCode[formatSet.Type]; !ok {
return formatSet, comboCharts, newUnsupportChartType(formatSet.Type) return formatSet, comboCharts, newUnsupportedChartType(formatSet.Type)
} }
return formatSet, comboCharts, err return formatSet, comboCharts, err
} }

@ -43,6 +43,7 @@ var (
packageOffset = 8 // First 8 bytes are the size of the stream packageOffset = 8 // First 8 bytes are the size of the stream
packageEncryptionChunkSize = 4096 packageEncryptionChunkSize = 4096
iterCount = 50000 iterCount = 50000
sheetProtectionSpinCount = 1e5
oleIdentifier = []byte{ oleIdentifier = []byte{
0xd0, 0xcf, 0x11, 0xe0, 0xa1, 0xb1, 0x1a, 0xe1, 0xd0, 0xcf, 0x11, 0xe0, 0xa1, 0xb1, 0x1a, 0xe1,
} }
@ -146,7 +147,7 @@ func Decrypt(raw []byte, opt *Options) (packageBuf []byte, err error) {
case "standard": case "standard":
return standardDecrypt(encryptionInfoBuf, encryptedPackageBuf, opt) return standardDecrypt(encryptionInfoBuf, encryptedPackageBuf, opt)
default: default:
err = ErrUnsupportEncryptMechanism err = ErrUnsupportedEncryptMechanism
} }
return return
} }
@ -307,7 +308,7 @@ func encryptionMechanism(buffer []byte) (mechanism string, err error) {
} else if (versionMajor == 3 || versionMajor == 4) && versionMinor == 3 { } else if (versionMajor == 3 || versionMajor == 4) && versionMinor == 3 {
mechanism = "extensible" mechanism = "extensible"
} }
err = ErrUnsupportEncryptMechanism err = ErrUnsupportedEncryptMechanism
return return
} }
@ -387,14 +388,14 @@ func standardConvertPasswdToKey(header StandardEncryptionHeader, verifier Standa
key = hashing("sha1", iterator, key) key = hashing("sha1", iterator, key)
} }
var block int var block int
hfinal := hashing("sha1", key, createUInt32LEBuffer(block, 4)) hFinal := hashing("sha1", key, createUInt32LEBuffer(block, 4))
cbRequiredKeyLength := int(header.KeySize) / 8 cbRequiredKeyLength := int(header.KeySize) / 8
cbHash := sha1.Size cbHash := sha1.Size
buf1 := bytes.Repeat([]byte{0x36}, 64) buf1 := bytes.Repeat([]byte{0x36}, 64)
buf1 = append(standardXORBytes(hfinal, buf1[:cbHash]), buf1[cbHash:]...) buf1 = append(standardXORBytes(hFinal, buf1[:cbHash]), buf1[cbHash:]...)
x1 := hashing("sha1", buf1) x1 := hashing("sha1", buf1)
buf2 := bytes.Repeat([]byte{0x5c}, 64) buf2 := bytes.Repeat([]byte{0x5c}, 64)
buf2 = append(standardXORBytes(hfinal, buf2[:cbHash]), buf2[cbHash:]...) buf2 = append(standardXORBytes(hFinal, buf2[:cbHash]), buf2[cbHash:]...)
x2 := hashing("sha1", buf2) x2 := hashing("sha1", buf2)
x3 := append(x1, x2...) x3 := append(x1, x2...)
keyDerived := x3[:cbRequiredKeyLength] keyDerived := x3[:cbRequiredKeyLength]
@ -417,7 +418,8 @@ func standardXORBytes(a, b []byte) []byte {
// ECMA-376 Agile Encryption // ECMA-376 Agile Encryption
// agileDecrypt decrypt the CFB file format with ECMA-376 agile encryption. // agileDecrypt decrypt the CFB file format with ECMA-376 agile encryption.
// Support cryptographic algorithm: MD4, MD5, RIPEMD-160, SHA1, SHA256, SHA384 and SHA512. // Support cryptographic algorithm: MD4, MD5, RIPEMD-160, SHA1, SHA256,
// SHA384 and SHA512.
func agileDecrypt(encryptionInfoBuf, encryptedPackageBuf []byte, opt *Options) (packageBuf []byte, err error) { func agileDecrypt(encryptionInfoBuf, encryptedPackageBuf []byte, opt *Options) (packageBuf []byte, err error) {
var encryptionInfo Encryption var encryptionInfo Encryption
if encryptionInfo, err = parseEncryptionInfo(encryptionInfoBuf[8:]); err != nil { if encryptionInfo, err = parseEncryptionInfo(encryptionInfoBuf[8:]); err != nil {
@ -605,11 +607,55 @@ func createIV(blockKey interface{}, encryption Encryption) ([]byte, error) {
return iv, nil return iv, nil
} }
// randomBytes returns securely generated random bytes. It will return an error if the system's // randomBytes returns securely generated random bytes. It will return an
// secure random number generator fails to function correctly, in which case the caller should not // error if the system's secure random number generator fails to function
// continue. // correctly, in which case the caller should not continue.
func randomBytes(n int) ([]byte, error) { func randomBytes(n int) ([]byte, error) {
b := make([]byte, n) b := make([]byte, n)
_, err := rand.Read(b) _, err := rand.Read(b)
return b, err return b, err
} }
// ISO Write Protection Method
// genISOPasswdHash implements the ISO password hashing algorithm by given
// plaintext password, name of the cryptographic hash algorithm, salt value
// and spin count.
func genISOPasswdHash(passwd, hashAlgorithm, salt string, spinCount int) (hashValue, saltValue string, err error) {
if len(passwd) < 1 || len(passwd) > MaxFieldLength {
err = ErrPasswordLengthInvalid
return
}
hash, ok := map[string]string{
"MD4": "md4",
"MD5": "md5",
"SHA-1": "sha1",
"SHA-256": "sha256",
"SHA-384": "sha384",
"SHA-512": "sha512",
}[hashAlgorithm]
if !ok {
err = ErrUnsupportedHashAlgorithm
return
}
var b bytes.Buffer
s, _ := randomBytes(16)
if salt != "" {
if s, err = base64.StdEncoding.DecodeString(salt); err != nil {
return
}
}
b.Write(s)
encoder := unicode.UTF16(unicode.LittleEndian, unicode.IgnoreBOM).NewEncoder()
passwordBuffer, _ := encoder.Bytes([]byte(passwd))
b.Write(passwordBuffer)
// Generate the initial hash.
key := hashing(hash, b.Bytes())
// Now regenerate until spin count.
for i := 0; i < spinCount; i++ {
iterator := createUInt32LEBuffer(i, 4)
key = hashing(hash, key, iterator)
}
hashValue, saltValue = base64.StdEncoding.EncodeToString(key), base64.StdEncoding.EncodeToString(s)
return
}

@ -28,11 +28,27 @@ func TestEncrypt(t *testing.T) {
func TestEncryptionMechanism(t *testing.T) { func TestEncryptionMechanism(t *testing.T) {
mechanism, err := encryptionMechanism([]byte{3, 0, 3, 0}) mechanism, err := encryptionMechanism([]byte{3, 0, 3, 0})
assert.Equal(t, mechanism, "extensible") assert.Equal(t, mechanism, "extensible")
assert.EqualError(t, err, ErrUnsupportEncryptMechanism.Error()) assert.EqualError(t, err, ErrUnsupportedEncryptMechanism.Error())
_, err = encryptionMechanism([]byte{}) _, err = encryptionMechanism([]byte{})
assert.EqualError(t, err, ErrUnknownEncryptMechanism.Error()) assert.EqualError(t, err, ErrUnknownEncryptMechanism.Error())
} }
func TestHashing(t *testing.T) { func TestHashing(t *testing.T) {
assert.Equal(t, hashing("unsupportHashAlgorithm", []byte{}), []uint8([]byte(nil))) assert.Equal(t, hashing("unsupportedHashAlgorithm", []byte{}), []uint8([]byte(nil)))
}
func TestGenISOPasswdHash(t *testing.T) {
for hashAlgorithm, expected := range map[string][]string{
"MD4": {"2lZQZUubVHLm/t6KsuHX4w==", "TTHjJdU70B/6Zq83XGhHVA=="},
"MD5": {"HWbqyd4dKKCjk1fEhk2kuQ==", "8ADyorkumWCayIukRhlVKQ=="},
"SHA-1": {"XErQIV3Ol+nhXkyCxrLTEQm+mSc=", "I3nDtyf59ASaNX1l6KpFnA=="},
"SHA-256": {"7oqMFyfED+mPrzRIBQ+KpKT4SClMHEPOZldliP15xAA=", "ru1R/w3P3Jna2Qo+EE8QiA=="},
"SHA-384": {"nMODLlxsC8vr0btcq0kp/jksg5FaI3az5Sjo1yZk+/x4bFzsuIvpDKUhJGAk/fzo", "Zjq9/jHlgOY6MzFDSlVNZg=="},
"SHA-512": {"YZ6jrGOFQgVKK3rDK/0SHGGgxEmFJglQIIRamZc2PkxVtUBp54fQn96+jVXEOqo6dtCSanqksXGcm/h3KaiR4Q==", "p5s/bybHBPtusI7EydTIrg=="},
} {
hashValue, saltValue, err := genISOPasswdHash("password", hashAlgorithm, expected[1], int(sheetProtectionSpinCount))
assert.NoError(t, err)
assert.Equal(t, expected[0], hashValue)
assert.Equal(t, expected[1], saltValue)
}
} }

@ -29,7 +29,7 @@ const (
DataValidationTypeDate DataValidationTypeDate
DataValidationTypeDecimal DataValidationTypeDecimal
typeList // inline use typeList // inline use
DataValidationTypeTextLeng DataValidationTypeTextLength
DataValidationTypeTime DataValidationTypeTime
// DataValidationTypeWhole Integer // DataValidationTypeWhole Integer
DataValidationTypeWhole DataValidationTypeWhole
@ -116,7 +116,7 @@ func (dd *DataValidation) SetInput(title, msg string) {
func (dd *DataValidation) SetDropList(keys []string) error { func (dd *DataValidation) SetDropList(keys []string) error {
formula := strings.Join(keys, ",") formula := strings.Join(keys, ",")
if MaxFieldLength < len(utf16.Encode([]rune(formula))) { if MaxFieldLength < len(utf16.Encode([]rune(formula))) {
return ErrDataValidationFormulaLenth return ErrDataValidationFormulaLength
} }
dd.Formula1 = fmt.Sprintf(`<formula1>"%s"</formula1>`, formulaEscaper.Replace(formula)) dd.Formula1 = fmt.Sprintf(`<formula1>"%s"</formula1>`, formulaEscaper.Replace(formula))
dd.Type = convDataValidationType(typeList) dd.Type = convDataValidationType(typeList)
@ -155,7 +155,7 @@ func (dd *DataValidation) SetRange(f1, f2 interface{}, t DataValidationType, o D
} }
dd.Formula1, dd.Formula2 = formula1, formula2 dd.Formula1, dd.Formula2 = formula1, formula2
dd.Type = convDataValidationType(t) dd.Type = convDataValidationType(t)
dd.Operator = convDataValidationOperatior(o) dd.Operator = convDataValidationOperator(o)
return nil return nil
} }
@ -192,22 +192,22 @@ func (dd *DataValidation) SetSqref(sqref string) {
// convDataValidationType get excel data validation type. // convDataValidationType get excel data validation type.
func convDataValidationType(t DataValidationType) string { func convDataValidationType(t DataValidationType) string {
typeMap := map[DataValidationType]string{ typeMap := map[DataValidationType]string{
typeNone: "none", typeNone: "none",
DataValidationTypeCustom: "custom", DataValidationTypeCustom: "custom",
DataValidationTypeDate: "date", DataValidationTypeDate: "date",
DataValidationTypeDecimal: "decimal", DataValidationTypeDecimal: "decimal",
typeList: "list", typeList: "list",
DataValidationTypeTextLeng: "textLength", DataValidationTypeTextLength: "textLength",
DataValidationTypeTime: "time", DataValidationTypeTime: "time",
DataValidationTypeWhole: "whole", DataValidationTypeWhole: "whole",
} }
return typeMap[t] return typeMap[t]
} }
// convDataValidationOperatior get excel data validation operator. // convDataValidationOperator get excel data validation operator.
func convDataValidationOperatior(o DataValidationOperator) string { func convDataValidationOperator(o DataValidationOperator) string {
typeMap := map[DataValidationOperator]string{ typeMap := map[DataValidationOperator]string{
DataValidationOperatorBetween: "between", DataValidationOperatorBetween: "between",
DataValidationOperatorEqual: "equal", DataValidationOperatorEqual: "equal",

@ -94,7 +94,7 @@ func TestDataValidationError(t *testing.T) {
t.Errorf("data validation error. Formula1 must be empty!") t.Errorf("data validation error. Formula1 must be empty!")
return return
} }
assert.EqualError(t, err, ErrDataValidationFormulaLenth.Error()) assert.EqualError(t, err, ErrDataValidationFormulaLength.Error())
assert.EqualError(t, dvRange.SetRange(nil, 20, DataValidationTypeWhole, DataValidationOperatorBetween), ErrParameterInvalid.Error()) assert.EqualError(t, dvRange.SetRange(nil, 20, DataValidationTypeWhole, DataValidationOperatorBetween), ErrParameterInvalid.Error())
assert.EqualError(t, dvRange.SetRange(10, nil, DataValidationTypeWhole, DataValidationOperatorBetween), ErrParameterInvalid.Error()) assert.EqualError(t, dvRange.SetRange(10, nil, DataValidationTypeWhole, DataValidationOperatorBetween), ErrParameterInvalid.Error())
assert.NoError(t, dvRange.SetRange(10, 20, DataValidationTypeWhole, DataValidationOperatorGreaterThan)) assert.NoError(t, dvRange.SetRange(10, 20, DataValidationTypeWhole, DataValidationOperatorGreaterThan))
@ -114,7 +114,7 @@ func TestDataValidationError(t *testing.T) {
err = dvRange.SetDropList(keys) err = dvRange.SetDropList(keys)
assert.Equal(t, prevFormula1, dvRange.Formula1, assert.Equal(t, prevFormula1, dvRange.Formula1,
"Formula1 should be unchanged for invalid input %v", keys) "Formula1 should be unchanged for invalid input %v", keys)
assert.EqualError(t, err, ErrDataValidationFormulaLenth.Error()) assert.EqualError(t, err, ErrDataValidationFormulaLength.Error())
} }
assert.NoError(t, f.AddDataValidation("Sheet1", dvRange)) assert.NoError(t, f.AddDataValidation("Sheet1", dvRange))
assert.NoError(t, dvRange.SetRange( assert.NoError(t, dvRange.SetRange(

@ -16,42 +16,50 @@ import (
"fmt" "fmt"
) )
// newInvalidColumnNameError defined the error message on receiving the invalid column name. // newInvalidColumnNameError defined the error message on receiving the
// invalid column name.
func newInvalidColumnNameError(col string) error { func newInvalidColumnNameError(col string) error {
return fmt.Errorf("invalid column name %q", col) return fmt.Errorf("invalid column name %q", col)
} }
// newInvalidRowNumberError defined the error message on receiving the invalid row number. // newInvalidRowNumberError defined the error message on receiving the invalid
// row number.
func newInvalidRowNumberError(row int) error { func newInvalidRowNumberError(row int) error {
return fmt.Errorf("invalid row number %d", row) return fmt.Errorf("invalid row number %d", row)
} }
// newInvalidCellNameError defined the error message on receiving the invalid cell name. // newInvalidCellNameError defined the error message on receiving the invalid
// cell name.
func newInvalidCellNameError(cell string) error { func newInvalidCellNameError(cell string) error {
return fmt.Errorf("invalid cell name %q", cell) return fmt.Errorf("invalid cell name %q", cell)
} }
// newInvalidExcelDateError defined the error message on receiving the data with negative values. // newInvalidExcelDateError defined the error message on receiving the data
// with negative values.
func newInvalidExcelDateError(dateValue float64) error { func newInvalidExcelDateError(dateValue float64) error {
return fmt.Errorf("invalid date value %f, negative values are not supported", dateValue) return fmt.Errorf("invalid date value %f, negative values are not supported", dateValue)
} }
// newUnsupportChartType defined the error message on receiving the chart type are unsupported. // newUnsupportedChartType defined the error message on receiving the chart
func newUnsupportChartType(chartType string) error { // type are unsupported.
func newUnsupportedChartType(chartType string) error {
return fmt.Errorf("unsupported chart type %s", chartType) return fmt.Errorf("unsupported chart type %s", chartType)
} }
// newUnzipSizeLimitError defined the error message on unzip size exceeds the limit. // newUnzipSizeLimitError defined the error message on unzip size exceeds the
// limit.
func newUnzipSizeLimitError(unzipSizeLimit int64) error { func newUnzipSizeLimitError(unzipSizeLimit int64) error {
return fmt.Errorf("unzip size exceeds the %d bytes limit", unzipSizeLimit) return fmt.Errorf("unzip size exceeds the %d bytes limit", unzipSizeLimit)
} }
// newInvalidStyleID defined the error message on receiving the invalid style ID. // newInvalidStyleID defined the error message on receiving the invalid style
// ID.
func newInvalidStyleID(styleID int) error { func newInvalidStyleID(styleID int) error {
return fmt.Errorf("invalid style ID %d, negative values are not supported", styleID) return fmt.Errorf("invalid style ID %d, negative values are not supported", styleID)
} }
// newFieldLengthError defined the error message on receiving the field length overflow. // newFieldLengthError defined the error message on receiving the field length
// overflow.
func newFieldLengthError(name string) error { func newFieldLengthError(name string) error {
return fmt.Errorf("field %s must be less or equal than 255 characters", name) return fmt.Errorf("field %s must be less or equal than 255 characters", name)
} }
@ -103,12 +111,18 @@ var (
ErrMaxFileNameLength = errors.New("file name length exceeds maximum limit") ErrMaxFileNameLength = errors.New("file name length exceeds maximum limit")
// ErrEncrypt defined the error message on encryption spreadsheet. // ErrEncrypt defined the error message on encryption spreadsheet.
ErrEncrypt = errors.New("not support encryption currently") ErrEncrypt = errors.New("not support encryption currently")
// ErrUnknownEncryptMechanism defined the error message on unsupport // ErrUnknownEncryptMechanism defined the error message on unsupported
// encryption mechanism. // encryption mechanism.
ErrUnknownEncryptMechanism = errors.New("unknown encryption mechanism") ErrUnknownEncryptMechanism = errors.New("unknown encryption mechanism")
// ErrUnsupportEncryptMechanism defined the error message on unsupport // ErrUnsupportedEncryptMechanism defined the error message on unsupported
// encryption mechanism. // encryption mechanism.
ErrUnsupportEncryptMechanism = errors.New("unsupport encryption mechanism") ErrUnsupportedEncryptMechanism = errors.New("unsupported encryption mechanism")
// ErrUnsupportedHashAlgorithm defined the error message on unsupported
// hash algorithm.
ErrUnsupportedHashAlgorithm = errors.New("unsupported hash algorithm")
// ErrPasswordLengthInvalid defined the error message on invalid password
// length.
ErrPasswordLengthInvalid = errors.New("password length invalid")
// ErrParameterRequired defined the error message on receive the empty // ErrParameterRequired defined the error message on receive the empty
// parameter. // parameter.
ErrParameterRequired = errors.New("parameter is required") ErrParameterRequired = errors.New("parameter is required")
@ -131,11 +145,17 @@ var (
// ErrSheetIdx defined the error message on receive the invalid worksheet // ErrSheetIdx defined the error message on receive the invalid worksheet
// index. // index.
ErrSheetIdx = errors.New("invalid worksheet index") ErrSheetIdx = errors.New("invalid worksheet index")
// ErrUnprotectSheet defined the error message on worksheet has set no
// protection.
ErrUnprotectSheet = errors.New("worksheet has set no protect")
// ErrUnprotectSheetPassword defined the error message on remove sheet
// protection with password verification failed.
ErrUnprotectSheetPassword = errors.New("worksheet protect password not match")
// ErrGroupSheets defined the error message on group sheets. // ErrGroupSheets defined the error message on group sheets.
ErrGroupSheets = errors.New("group worksheet must contain an active worksheet") ErrGroupSheets = errors.New("group worksheet must contain an active worksheet")
// ErrDataValidationFormulaLenth defined the error message for receiving a // ErrDataValidationFormulaLength defined the error message for receiving a
// data validation formula length that exceeds the limit. // data validation formula length that exceeds the limit.
ErrDataValidationFormulaLenth = errors.New("data validation must be 0-255 characters") ErrDataValidationFormulaLength = errors.New("data validation must be 0-255 characters")
// ErrDataValidationRange defined the error message on set decimal range // ErrDataValidationRange defined the error message on set decimal range
// exceeds limit. // exceeds limit.
ErrDataValidationRange = errors.New("data validation range exceeds limit") ErrDataValidationRange = errors.New("data validation range exceeds limit")
@ -164,5 +184,5 @@ var (
ErrSparkline = errors.New("must have the same number of 'Location' and 'Range' parameters") ErrSparkline = errors.New("must have the same number of 'Location' and 'Range' parameters")
// ErrSparklineStyle defined the error message on receive the invalid // ErrSparklineStyle defined the error message on receive the invalid
// sparkline Style parameters. // sparkline Style parameters.
ErrSparklineStyle = errors.New("parameter 'Style' must betweent 0-35") ErrSparklineStyle = errors.New("parameter 'Style' must between 0-35")
) )

@ -1160,13 +1160,44 @@ func TestHSL(t *testing.T) {
func TestProtectSheet(t *testing.T) { func TestProtectSheet(t *testing.T) {
f := NewFile() f := NewFile()
assert.NoError(t, f.ProtectSheet("Sheet1", nil)) sheetName := f.GetSheetName(0)
assert.NoError(t, f.ProtectSheet("Sheet1", &FormatSheetProtection{ assert.NoError(t, f.ProtectSheet(sheetName, nil))
// Test protect worksheet with XOR hash algorithm
assert.NoError(t, f.ProtectSheet(sheetName, &FormatSheetProtection{
Password: "password", Password: "password",
EditScenarios: false, EditScenarios: false,
})) }))
ws, err := f.workSheetReader(sheetName)
assert.NoError(t, err)
assert.Equal(t, "83AF", ws.SheetProtection.Password)
assert.NoError(t, f.SaveAs(filepath.Join("test", "TestProtectSheet.xlsx"))) assert.NoError(t, f.SaveAs(filepath.Join("test", "TestProtectSheet.xlsx")))
// Test protect worksheet with SHA-512 hash algorithm
assert.NoError(t, f.ProtectSheet(sheetName, &FormatSheetProtection{
AlgorithmName: "SHA-512",
Password: "password",
}))
ws, err = f.workSheetReader(sheetName)
assert.NoError(t, err)
assert.Equal(t, 24, len(ws.SheetProtection.SaltValue))
assert.Equal(t, 88, len(ws.SheetProtection.HashValue))
assert.Equal(t, int(sheetProtectionSpinCount), ws.SheetProtection.SpinCount)
// Test remove sheet protection with an incorrect password
assert.EqualError(t, f.UnprotectSheet(sheetName, "wrongPassword"), ErrUnprotectSheetPassword.Error())
// Test remove sheet protection with password verification
assert.NoError(t, f.UnprotectSheet(sheetName, "password"))
// Test protect worksheet with empty password
assert.NoError(t, f.ProtectSheet(sheetName, &FormatSheetProtection{}))
assert.Equal(t, "", ws.SheetProtection.Password)
// Test protect worksheet with password exceeds the limit length
assert.EqualError(t, f.ProtectSheet(sheetName, &FormatSheetProtection{
AlgorithmName: "MD4",
Password: strings.Repeat("s", MaxFieldLength+1),
}), ErrPasswordLengthInvalid.Error())
// Test protect worksheet with unsupported hash algorithm
assert.EqualError(t, f.ProtectSheet(sheetName, &FormatSheetProtection{
AlgorithmName: "RIPEMD-160",
Password: "password",
}), ErrUnsupportedHashAlgorithm.Error())
// Test protect not exists worksheet. // Test protect not exists worksheet.
assert.EqualError(t, f.ProtectSheet("SheetN", nil), "sheet SheetN is not exist") assert.EqualError(t, f.ProtectSheet("SheetN", nil), "sheet SheetN is not exist")
} }
@ -1176,12 +1207,30 @@ func TestUnprotectSheet(t *testing.T) {
if !assert.NoError(t, err) { if !assert.NoError(t, err) {
t.FailNow() t.FailNow()
} }
// Test unprotect not exists worksheet. // Test remove protection on not exists worksheet.
assert.EqualError(t, f.UnprotectSheet("SheetN"), "sheet SheetN is not exist") assert.EqualError(t, f.UnprotectSheet("SheetN"), "sheet SheetN is not exist")
assert.NoError(t, f.UnprotectSheet("Sheet1")) assert.NoError(t, f.UnprotectSheet("Sheet1"))
assert.EqualError(t, f.UnprotectSheet("Sheet1", "password"), ErrUnprotectSheet.Error())
assert.NoError(t, f.SaveAs(filepath.Join("test", "TestUnprotectSheet.xlsx"))) assert.NoError(t, f.SaveAs(filepath.Join("test", "TestUnprotectSheet.xlsx")))
assert.NoError(t, f.Close()) assert.NoError(t, f.Close())
f = NewFile()
sheetName := f.GetSheetName(0)
assert.NoError(t, f.ProtectSheet(sheetName, &FormatSheetProtection{Password: "password"}))
// Test remove sheet protection with an incorrect password
assert.EqualError(t, f.UnprotectSheet(sheetName, "wrongPassword"), ErrUnprotectSheetPassword.Error())
// Test remove sheet protection with password verification
assert.NoError(t, f.UnprotectSheet(sheetName, "password"))
// Test with invalid salt value
assert.NoError(t, f.ProtectSheet(sheetName, &FormatSheetProtection{
AlgorithmName: "SHA-512",
Password: "password",
}))
ws, err := f.workSheetReader(sheetName)
assert.NoError(t, err)
ws.SheetProtection.SaltValue = "YWJjZA====="
assert.EqualError(t, f.UnprotectSheet(sheetName, "wrongPassword"), "illegal base64 data at input byte 8")
} }
func TestSetDefaultTimeStyle(t *testing.T) { func TestSetDefaultTimeStyle(t *testing.T) {

@ -222,7 +222,7 @@ func TestAddPivotTable(t *testing.T) {
PivotTableRange: "Sheet1!$G$2:$M$34", PivotTableRange: "Sheet1!$G$2:$M$34",
Rows: []PivotTableField{{Data: "Month", DefaultSubtotal: true}, {Data: "Year"}}, Rows: []PivotTableField{{Data: "Month", DefaultSubtotal: true}, {Data: "Year"}},
Columns: []PivotTableField{{Data: "Type", DefaultSubtotal: true}}, Columns: []PivotTableField{{Data: "Type", DefaultSubtotal: true}},
Data: []PivotTableField{{Data: "Sales", Subtotal: "-", Name: strings.Repeat("s", 256)}}, Data: []PivotTableField{{Data: "Sales", Subtotal: "-", Name: strings.Repeat("s", MaxFieldLength+1)}},
})) }))
// Test adjust range with invalid range // Test adjust range with invalid range

@ -1129,10 +1129,14 @@ func (f *File) SetHeaderFooter(sheet string, settings *FormatHeaderFooter) error
} }
// ProtectSheet provides a function to prevent other users from accidentally // ProtectSheet provides a function to prevent other users from accidentally
// or deliberately changing, moving, or deleting data in a worksheet. For // or deliberately changing, moving, or deleting data in a worksheet. The
// example, protect Sheet1 with protection settings: // optional field AlgorithmName specified hash algorithm, support XOR, MD4,
// MD5, SHA1, SHA256, SHA384, and SHA512 currently, if no hash algorithm
// specified, will be using the XOR algorithm as default. For example,
// protect Sheet1 with protection settings:
// //
// err := f.ProtectSheet("Sheet1", &excelize.FormatSheetProtection{ // err := f.ProtectSheet("Sheet1", &excelize.FormatSheetProtection{
// AlgorithmName: "SHA-512",
// Password: "password", // Password: "password",
// EditScenarios: false, // EditScenarios: false,
// }) // })
@ -1168,22 +1172,55 @@ func (f *File) ProtectSheet(sheet string, settings *FormatSheetProtection) error
Sort: settings.Sort, Sort: settings.Sort,
} }
if settings.Password != "" { if settings.Password != "" {
ws.SheetProtection.Password = genSheetPasswd(settings.Password) if settings.AlgorithmName == "" {
ws.SheetProtection.Password = genSheetPasswd(settings.Password)
return err
}
hashValue, saltValue, err := genISOPasswdHash(settings.Password, settings.AlgorithmName, "", int(sheetProtectionSpinCount))
if err != nil {
return err
}
ws.SheetProtection.Password = ""
ws.SheetProtection.AlgorithmName = settings.AlgorithmName
ws.SheetProtection.SaltValue = saltValue
ws.SheetProtection.HashValue = hashValue
ws.SheetProtection.SpinCount = int(sheetProtectionSpinCount)
} }
return err return err
} }
// UnprotectSheet provides a function to unprotect an Excel worksheet. // UnprotectSheet provides a function to remove protection for a sheet,
func (f *File) UnprotectSheet(sheet string) error { // specified the second optional password parameter to remove sheet
// protection with password verification.
func (f *File) UnprotectSheet(sheet string, password ...string) error {
ws, err := f.workSheetReader(sheet) ws, err := f.workSheetReader(sheet)
if err != nil { if err != nil {
return err return err
} }
// password verification
if len(password) > 0 {
if ws.SheetProtection == nil {
return ErrUnprotectSheet
}
if ws.SheetProtection.AlgorithmName == "" && ws.SheetProtection.Password != genSheetPasswd(password[0]) {
return ErrUnprotectSheetPassword
}
if ws.SheetProtection.AlgorithmName != "" {
// check with given salt value
hashValue, _, err := genISOPasswdHash(password[0], ws.SheetProtection.AlgorithmName, ws.SheetProtection.SaltValue, ws.SheetProtection.SpinCount)
if err != nil {
return err
}
if ws.SheetProtection.HashValue != hashValue {
return ErrUnprotectSheetPassword
}
}
}
ws.SheetProtection = nil ws.SheetProtection = nil
return err return err
} }
// trimSheetName provides a function to trim invaild characters by given worksheet // trimSheetName provides a function to trim invalid characters by given worksheet
// name. // name.
func trimSheetName(name string) string { func trimSheetName(name string) string {
if strings.ContainsAny(name, ":\\/?*[]") || utf8.RuneCountInString(name) > 31 { if strings.ContainsAny(name, ":\\/?*[]") || utf8.RuneCountInString(name) > 31 {

@ -838,6 +838,7 @@ type formatConditional struct {
// FormatSheetProtection directly maps the settings of worksheet protection. // FormatSheetProtection directly maps the settings of worksheet protection.
type FormatSheetProtection struct { type FormatSheetProtection struct {
AlgorithmName string
AutoFilter bool AutoFilter bool
DeleteColumns bool DeleteColumns bool
DeleteRows bool DeleteRows bool

Loading…
Cancel
Save