diff --git a/lgensemble_io.go b/lgensemble_io.go index c25aa45..a7b8cc2 100644 --- a/lgensemble_io.go +++ b/lgensemble_io.go @@ -298,11 +298,10 @@ func LGEnsembleFromReader(reader *bufio.Reader, loadTransformation bool) (*Ensem return nil, err } - if err := params.Compare("version", "v2"); err != nil { - if err := params.Compare("version", "v3"); err != nil { - return nil, err - } + if err := params.CompareAll("version", []string{"v2", "v3", "v4"}); err != nil { + return nil, err } + nClasses, err := params.ToInt("num_class") if err != nil { return nil, err diff --git a/util/util.go b/util/util.go index 5aa0081..0acdac9 100644 --- a/util/util.go +++ b/util/util.go @@ -95,6 +95,20 @@ func (p *stringParams) Compare(key string, rhs string) error { return nil } +func (p *stringParams) CompareAll(key string, values []string) error { + valueStr, isFound := (*p)[key] + if !isFound { + return fmt.Errorf("no %s field", key) + } + + for _, v := range values { + if valueStr == v { + return nil + } + } + return fmt.Errorf("only %v are supported for %s, got %s", values, key, valueStr) +} + func (p *stringParams) ToStrSlice(key string) ([]string, error) { valueStr, isFound := (*p)[key] if !isFound { @@ -267,7 +281,7 @@ func SoftmaxFloat64Slice(rawValues []float64, outputValues []float64, startIndex } if sum != 0.0 { inv_sum := 1.0 / sum - for i := startIndex; i < startIndex + len(rawValues); i++ { + for i := startIndex; i < startIndex+len(rawValues); i++ { outputValues[i] *= inv_sum } }