Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions pkg/sqlutil/interval.go
Comment thread
jmank88 marked this conversation as resolved.
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package sqlutil

import (
"database/sql/driver"
"fmt"
"time"
)

// Interval represents a time.Duration stored as a Postgres interval type
type Interval time.Duration

// NewInterval creates Interval for specified duration
func NewInterval(d time.Duration) *Interval {
i := new(Interval)
*i = Interval(d)
return i
}

func (i Interval) Duration() time.Duration {
return time.Duration(i)
}

// MarshalText implements the text.Marshaler interface.
func (i Interval) MarshalText() ([]byte, error) {
return []byte(time.Duration(i).String()), nil
}

// UnmarshalText implements the text.Unmarshaler interface.
func (i *Interval) UnmarshalText(input []byte) error {
v, err := time.ParseDuration(string(input))
if err != nil {
return err
}
*i = Interval(v)
return nil
}

func (i *Interval) Scan(v interface{}) error {
if v == nil {
*i = Interval(time.Duration(0))
return nil
}
asInt64, is := v.(int64)
if !is {
return fmt.Errorf("models.Interval#Scan() wanted int64, got %T", v)
}
*i = Interval(time.Duration(asInt64) * time.Nanosecond)
return nil
}

func (i Interval) Value() (driver.Value, error) {
return time.Duration(i).Nanoseconds(), nil
}

func (i Interval) IsZero() bool {
return time.Duration(i) == time.Duration(0)
}
61 changes: 61 additions & 0 deletions pkg/sqlutil/interval_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package sqlutil

import (
"testing"
"time"

"github.com/stretchr/testify/require"
)

func TestNewInterval(t *testing.T) {
t.Parallel()

duration := 33 * time.Second
interval := NewInterval(duration)

require.Equal(t, duration, interval.Duration())
}

func TestInterval_IsZero(t *testing.T) {
t.Parallel()

i := NewInterval(0)
require.NotNil(t, i)
require.True(t, i.IsZero())

i = NewInterval(1)
require.NotNil(t, i)
require.False(t, i.IsZero())
}

func TestInterval_Scan_Value(t *testing.T) {
t.Parallel()

i := NewInterval(100)
require.NotNil(t, i)

val, err := i.Value()
require.NoError(t, err)

iNew := NewInterval(0)
err = iNew.Scan(val)
require.NoError(t, err)

require.Equal(t, i, iNew)
}

func TestInterval_MarshalText_UnmarshalText(t *testing.T) {
t.Parallel()

i := NewInterval(100)
require.NotNil(t, i)

txt, err := i.MarshalText()
require.NoError(t, err)

iNew := NewInterval(0)
err = iNew.UnmarshalText(txt)
require.NoError(t, err)

require.Equal(t, i, iNew)
}
Loading