diff --git a/pkg/sqlutil/interval.go b/pkg/sqlutil/interval.go new file mode 100644 index 0000000000..1491b2986b --- /dev/null +++ b/pkg/sqlutil/interval.go @@ -0,0 +1,57 @@ +package sqlutil + +import ( + "database/sql/driver" + "fmt" + "time" +) + +// Interval represents a time.Duration stored as a Postgres interval type +type Interval time.Duration + +// NewInterval creates Interval for specified duration +func NewInterval(d time.Duration) *Interval { + i := new(Interval) + *i = Interval(d) + return i +} + +func (i Interval) Duration() time.Duration { + return time.Duration(i) +} + +// MarshalText implements the text.Marshaler interface. +func (i Interval) MarshalText() ([]byte, error) { + return []byte(time.Duration(i).String()), nil +} + +// UnmarshalText implements the text.Unmarshaler interface. +func (i *Interval) UnmarshalText(input []byte) error { + v, err := time.ParseDuration(string(input)) + if err != nil { + return err + } + *i = Interval(v) + return nil +} + +func (i *Interval) Scan(v interface{}) error { + if v == nil { + *i = Interval(time.Duration(0)) + return nil + } + asInt64, is := v.(int64) + if !is { + return fmt.Errorf("models.Interval#Scan() wanted int64, got %T", v) + } + *i = Interval(time.Duration(asInt64) * time.Nanosecond) + return nil +} + +func (i Interval) Value() (driver.Value, error) { + return time.Duration(i).Nanoseconds(), nil +} + +func (i Interval) IsZero() bool { + return time.Duration(i) == time.Duration(0) +} diff --git a/pkg/sqlutil/interval_test.go b/pkg/sqlutil/interval_test.go new file mode 100644 index 0000000000..32b6b0c432 --- /dev/null +++ b/pkg/sqlutil/interval_test.go @@ -0,0 +1,61 @@ +package sqlutil + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestNewInterval(t *testing.T) { + t.Parallel() + + duration := 33 * time.Second + interval := NewInterval(duration) + + require.Equal(t, duration, interval.Duration()) +} + +func TestInterval_IsZero(t *testing.T) { + t.Parallel() + + i := NewInterval(0) + require.NotNil(t, i) + require.True(t, i.IsZero()) + + i = NewInterval(1) + require.NotNil(t, i) + require.False(t, i.IsZero()) +} + +func TestInterval_Scan_Value(t *testing.T) { + t.Parallel() + + i := NewInterval(100) + require.NotNil(t, i) + + val, err := i.Value() + require.NoError(t, err) + + iNew := NewInterval(0) + err = iNew.Scan(val) + require.NoError(t, err) + + require.Equal(t, i, iNew) +} + +func TestInterval_MarshalText_UnmarshalText(t *testing.T) { + t.Parallel() + + i := NewInterval(100) + require.NotNil(t, i) + + txt, err := i.MarshalText() + require.NoError(t, err) + + iNew := NewInterval(0) + err = iNew.UnmarshalText(txt) + require.NoError(t, err) + + require.Equal(t, i, iNew) +}