Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "TensorBoardLogger"
uuid = "899adc3e-224a-11e9-021f-63837185c80f"
authors = ["Filippo Vicentini <filippovicentini@gmail.com>"]
version = "0.1.26"
authors = ["Filippo Vicentini <filippovicentini@gmail.com>"]

[deps]
CRC32c = "8bf52ea8-c179-5cab-976a-9e18b702a9bc"
Expand All @@ -14,10 +14,10 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
[compat]
FileIO = "1.2.3"
ImageCore = "0.8.1, 0.9, 0.10"
ProtoBuf = "1.0.11"
ProtoBuf = "1"
Requires = "0.5, 1"
StatsBase = "0.27, 0.28, 0.29, 0.30, 0.31, 0.32, 0.33, 0.34"
julia = "1.6"
julia = "1.10"

[extras]
Cairo = "159f3aea-2a34-519c-b102-8c37f9878175"
Expand Down
2 changes: 1 addition & 1 deletion gen/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@ Glob = "c27321d9-0574-5035-807b-f59d2c89b15c"
ProtoBuf = "3349acd9-ac6a-5e09-bcdb-63829b23a429"

[compat]
ProtoBuf = "0.9.1"
ProtoBuf = "1"
17 changes: 16 additions & 1 deletion gen/compile_proto.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,23 @@ plugins = ["custom_scalar", "hparams", "text"]

append!(files_to_include, (process_module("tensorboard/plugins/$plugin", base_module="tensorboard") for plugin in plugins)...)

const PLUGIN_MODULE_NAMES = Dict(
"text" => "tensorboard_plugin_text",
"hparams" => "tensorboard_plugin_hparams",
"custom_scalar" => "tensorboard_plugin_custom_scalar",
)

for (plugin, module_name) in PLUGIN_MODULE_NAMES
wrapper = out_dir / "tensorboard" / "plugins" / plugin / "tensorboard" / "tensorboard.jl"
if isfile(wrapper)
content = read(wrapper, String)
write(wrapper, replace(content, "module tensorboard" => "module $module_name", "end # module tensorboard" => "end # module $module_name"))
end
end

# files_to_include contains all the proto files, can be used for printing and inspection
println("generated code for \n$files_to_include")

# Finally move the output directory to the src folder
mv(out_dir, TBL_root/"src"/"protojl")
rm(TBL_root / "src" / "protojl"; force = true, recursive = true)
mv(out_dir, TBL_root / "src" / "protojl")
4 changes: 3 additions & 1 deletion gen/download_proto_source
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,10 @@ mv proto/tensorboard/compat/proto/struct.proto proto/tensorboard/compat/proto/st
if [[ $(uname) == "Darwin" ]]; then
echo "Workaround for sed on macos."
LC_ALL=C
grep -rl "compat/proto/struct.proto" . | xargs sed -i '' -e 's+compat/proto/struct\.proto+compat/proto/struct_tb\.proto+g'
else
grep -rl "compat/proto/struct.proto" . | xargs sed -i -e 's+compat/proto/struct\.proto+compat/proto/struct_tb\.proto+g'
fi
grep -rl "compat/proto/struct.proto" . | xargs sed -i '' -e 's+compat/proto/struct\.proto+compat/proto/struct_tb\.proto+g'

cd ..
mv tmp/proto proto
Expand Down
5 changes: 0 additions & 5 deletions src/TensorBoardLogger.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,6 @@ using ImageCore: colorview, channelview
using ImageCore: Colorant, Gray, GrayA, RGB, RGBA
using FileIO: FileIO, @format_str, Stream, save, load

# hasproperty is not defined before 1.2. This is Compat.hasproperty
if VERSION < v"1.2.0-DEV.272"
using ProtoBuf: hasproperty
end

#TODO: Is there a more lightweight package for compmuting an histogram?
using StatsBase: Histogram, fit

Expand Down
45 changes: 1 addition & 44 deletions src/hparams.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import .tensorboard_plugin_hparams.hparams: var"#DataType" as HParamDataType, DatasetType as HDatasetType
import .tensorboard_plugin_hparams.google.protobuf: ListValue as HListValue, Value as HValue
import .tensorboard_plugin_hparams.hparams as HP
import ProtoBuf as PB

struct HParamRealDomain
min_value::Float64
Expand Down Expand Up @@ -66,55 +65,13 @@ function hparam_info(c::HParamConfig)

dtype = _to_proto_hparam_dtype(Val(datatype))
converted_domain = _convert_hparam_domain(domain)
return HP.HParamInfo(c.name, c.displayname, c.description, dtype, converted_domain)
return HP.HParamInfo(c.name, c.displayname, c.description, dtype, converted_domain, false)
end
function metric_info(c::MetricConfig)
mname = HP.MetricName("", c.name)
return HP.MetricInfo(mname, c.displayname, c.description, HDatasetType.DATASET_UNKNOWN)
end

# Dictionary serialisation in ProtoBuf does not work for this specific map type
# and must be overloaded so that it can be parsed. The format was derived by
# looking at the binary output of a log file created by tensorboardX.
# These protobuf overloads should be removed once https://github.com/JuliaIO/ProtoBuf.jl/pull/234 is merged.
function PB.encode(e::ProtoEncoder, i::Int, x::Dict{String,HValue})
for (k, v) in x
PB.Codecs.encode_tag(e, 1, PB.Codecs.LENGTH_DELIMITED)
total_size = PB.Codecs._encoded_size(k, 1) + PB.Codecs._encoded_size(v, 2)
PB.Codecs.vbyte_encode(e.io, UInt32(total_size)) # Add two for the wire type and length
PB.Codecs.encode(e, 1, k)
PB.Codecs.encode(e, 2, v)
end
return nothing
end

# Similarly, we must overload the size calculation to take into account the new
# format.
function PB.Codecs._encoded_size(d::Dict{String,HValue}, i::Int)
mapreduce(x->begin
total_size = PB.Codecs._encoded_size(x.first, 1) + PB.Codecs._encoded_size(x.second, 2)
return 1 + PB.Codecs._varint_size(total_size) + total_size
end, +, d, init=0)
end

function PB.Codecs.decode!(d::ProtoDecoder, buffer::Dict{String,HValue})
len = PB.Codecs.vbyte_decode(d.io, UInt32)
endpos = position(d.io) + len
while position(d.io) < endpos
pair_field_number, pair_wire_type = PB.Codecs.decode_tag(d)
pair_len = PB.Codecs.vbyte_decode(d.io, UInt32)
pair_end_pos = position(d.io) + pair_len
field_number, wire_type = PB.Codecs.decode_tag(d)
key = PB.Codecs.decode(d, K)
field_number, wire_type = PB.Codecs.decode_tag(d)
val = PB.Codecs.decode(d, Ref{V})
@assert position(d.io) == pair_end_pos
buffer[key] = val
end
@assert position(d.io) == endpos
nothing
end

"""
write_hparams!(logger::TBLogger, hparams::Dict{String, Any}, metrics::AbstractArray{String})

Expand Down
11 changes: 6 additions & 5 deletions src/protojl/tensorboard/google/protobuf/any_pb.jl
Original file line number Diff line number Diff line change
@@ -1,30 +1,31 @@
# Autogenerated using ProtoBuf.jl v1.0.11 on 2023-06-19T18:18:24.779
# original file: /home/lior/.julia/dev/ProtoBuf/src/google/protobuf/any.proto (proto3 syntax)
# Autogenerated using ProtoBuf.jl v1.3.0
# original file: google/protobuf/any.proto (proto3 syntax)

import ProtoBuf as PB
using ProtoBuf: OneOf
using ProtoBuf.EnumX: @enumx

export var"#Any"


struct var"#Any"
type_url::String
value::Vector{UInt8}
end
PB.default_values(::Type{var"#Any"}) = (;type_url = "", value = UInt8[])
PB.field_numbers(::Type{var"#Any"}) = (;type_url = 1, value = 2)

function PB.decode(d::PB.AbstractProtoDecoder, ::Type{<:var"#Any"})
function PB.decode(d::PB.AbstractProtoDecoder, ::Type{<:var"#Any"}, _endpos::Int=0, _group::Bool=false)
type_url = ""
value = UInt8[]
while !PB.message_done(d)
while !PB.message_done(d, _endpos, _group)
field_number, wire_type = PB.decode_tag(d)
if field_number == 1
type_url = PB.decode(d, String)
elseif field_number == 2
value = PB.decode(d, Vector{UInt8})
else
PB.skip(d, wire_type)
Base.skip(d, wire_type)
end
end
return var"#Any"(type_url, value)
Expand Down
67 changes: 34 additions & 33 deletions src/protojl/tensorboard/google/protobuf/wrappers_pb.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Autogenerated using ProtoBuf.jl v1.0.11 on 2023-06-19T18:18:24.780
# original file: /home/lior/.julia/dev/ProtoBuf/src/google/protobuf/wrappers.proto (proto3 syntax)
# Autogenerated using ProtoBuf.jl v1.3.0
# original file: google/protobuf/wrappers.proto (proto3 syntax)

import ProtoBuf as PB
using ProtoBuf: OneOf
Expand All @@ -8,20 +8,21 @@ using ProtoBuf.EnumX: @enumx
export BoolValue, Int64Value, FloatValue, Int32Value, DoubleValue, UInt64Value, UInt32Value
export BytesValue, StringValue


struct BoolValue
value::Bool
end
PB.default_values(::Type{BoolValue}) = (;value = false)
PB.field_numbers(::Type{BoolValue}) = (;value = 1)

function PB.decode(d::PB.AbstractProtoDecoder, ::Type{<:BoolValue})
function PB.decode(d::PB.AbstractProtoDecoder, ::Type{<:BoolValue}, _endpos::Int=0, _group::Bool=false)
value = false
while !PB.message_done(d)
while !PB.message_done(d, _endpos, _group)
field_number, wire_type = PB.decode_tag(d)
if field_number == 1
value = PB.decode(d, Bool)
else
PB.skip(d, wire_type)
Base.skip(d, wire_type)
end
end
return BoolValue(value)
Expand All @@ -44,14 +45,14 @@ end
PB.default_values(::Type{Int64Value}) = (;value = zero(Int64))
PB.field_numbers(::Type{Int64Value}) = (;value = 1)

function PB.decode(d::PB.AbstractProtoDecoder, ::Type{<:Int64Value})
function PB.decode(d::PB.AbstractProtoDecoder, ::Type{<:Int64Value}, _endpos::Int=0, _group::Bool=false)
value = zero(Int64)
while !PB.message_done(d)
while !PB.message_done(d, _endpos, _group)
field_number, wire_type = PB.decode_tag(d)
if field_number == 1
value = PB.decode(d, Int64)
else
PB.skip(d, wire_type)
Base.skip(d, wire_type)
end
end
return Int64Value(value)
Expand All @@ -74,27 +75,27 @@ end
PB.default_values(::Type{FloatValue}) = (;value = zero(Float32))
PB.field_numbers(::Type{FloatValue}) = (;value = 1)

function PB.decode(d::PB.AbstractProtoDecoder, ::Type{<:FloatValue})
function PB.decode(d::PB.AbstractProtoDecoder, ::Type{<:FloatValue}, _endpos::Int=0, _group::Bool=false)
value = zero(Float32)
while !PB.message_done(d)
while !PB.message_done(d, _endpos, _group)
field_number, wire_type = PB.decode_tag(d)
if field_number == 1
value = PB.decode(d, Float32)
else
PB.skip(d, wire_type)
Base.skip(d, wire_type)
end
end
return FloatValue(value)
end

function PB.encode(e::PB.AbstractProtoEncoder, x::FloatValue)
initpos = position(e.io)
x.value != zero(Float32) && PB.encode(e, 1, x.value)
x.value !== zero(Float32) && PB.encode(e, 1, x.value)
return position(e.io) - initpos
end
function PB._encoded_size(x::FloatValue)
encoded_size = 0
x.value != zero(Float32) && (encoded_size += PB._encoded_size(x.value, 1))
x.value !== zero(Float32) && (encoded_size += PB._encoded_size(x.value, 1))
return encoded_size
end

Expand All @@ -104,14 +105,14 @@ end
PB.default_values(::Type{Int32Value}) = (;value = zero(Int32))
PB.field_numbers(::Type{Int32Value}) = (;value = 1)

function PB.decode(d::PB.AbstractProtoDecoder, ::Type{<:Int32Value})
function PB.decode(d::PB.AbstractProtoDecoder, ::Type{<:Int32Value}, _endpos::Int=0, _group::Bool=false)
value = zero(Int32)
while !PB.message_done(d)
while !PB.message_done(d, _endpos, _group)
field_number, wire_type = PB.decode_tag(d)
if field_number == 1
value = PB.decode(d, Int32)
else
PB.skip(d, wire_type)
Base.skip(d, wire_type)
end
end
return Int32Value(value)
Expand All @@ -134,27 +135,27 @@ end
PB.default_values(::Type{DoubleValue}) = (;value = zero(Float64))
PB.field_numbers(::Type{DoubleValue}) = (;value = 1)

function PB.decode(d::PB.AbstractProtoDecoder, ::Type{<:DoubleValue})
function PB.decode(d::PB.AbstractProtoDecoder, ::Type{<:DoubleValue}, _endpos::Int=0, _group::Bool=false)
value = zero(Float64)
while !PB.message_done(d)
while !PB.message_done(d, _endpos, _group)
field_number, wire_type = PB.decode_tag(d)
if field_number == 1
value = PB.decode(d, Float64)
else
PB.skip(d, wire_type)
Base.skip(d, wire_type)
end
end
return DoubleValue(value)
end

function PB.encode(e::PB.AbstractProtoEncoder, x::DoubleValue)
initpos = position(e.io)
x.value != zero(Float64) && PB.encode(e, 1, x.value)
x.value !== zero(Float64) && PB.encode(e, 1, x.value)
return position(e.io) - initpos
end
function PB._encoded_size(x::DoubleValue)
encoded_size = 0
x.value != zero(Float64) && (encoded_size += PB._encoded_size(x.value, 1))
x.value !== zero(Float64) && (encoded_size += PB._encoded_size(x.value, 1))
return encoded_size
end

Expand All @@ -164,14 +165,14 @@ end
PB.default_values(::Type{UInt64Value}) = (;value = zero(UInt64))
PB.field_numbers(::Type{UInt64Value}) = (;value = 1)

function PB.decode(d::PB.AbstractProtoDecoder, ::Type{<:UInt64Value})
function PB.decode(d::PB.AbstractProtoDecoder, ::Type{<:UInt64Value}, _endpos::Int=0, _group::Bool=false)
value = zero(UInt64)
while !PB.message_done(d)
while !PB.message_done(d, _endpos, _group)
field_number, wire_type = PB.decode_tag(d)
if field_number == 1
value = PB.decode(d, UInt64)
else
PB.skip(d, wire_type)
Base.skip(d, wire_type)
end
end
return UInt64Value(value)
Expand All @@ -194,14 +195,14 @@ end
PB.default_values(::Type{UInt32Value}) = (;value = zero(UInt32))
PB.field_numbers(::Type{UInt32Value}) = (;value = 1)

function PB.decode(d::PB.AbstractProtoDecoder, ::Type{<:UInt32Value})
function PB.decode(d::PB.AbstractProtoDecoder, ::Type{<:UInt32Value}, _endpos::Int=0, _group::Bool=false)
value = zero(UInt32)
while !PB.message_done(d)
while !PB.message_done(d, _endpos, _group)
field_number, wire_type = PB.decode_tag(d)
if field_number == 1
value = PB.decode(d, UInt32)
else
PB.skip(d, wire_type)
Base.skip(d, wire_type)
end
end
return UInt32Value(value)
Expand All @@ -224,14 +225,14 @@ end
PB.default_values(::Type{BytesValue}) = (;value = UInt8[])
PB.field_numbers(::Type{BytesValue}) = (;value = 1)

function PB.decode(d::PB.AbstractProtoDecoder, ::Type{<:BytesValue})
function PB.decode(d::PB.AbstractProtoDecoder, ::Type{<:BytesValue}, _endpos::Int=0, _group::Bool=false)
value = UInt8[]
while !PB.message_done(d)
while !PB.message_done(d, _endpos, _group)
field_number, wire_type = PB.decode_tag(d)
if field_number == 1
value = PB.decode(d, Vector{UInt8})
else
PB.skip(d, wire_type)
Base.skip(d, wire_type)
end
end
return BytesValue(value)
Expand All @@ -254,14 +255,14 @@ end
PB.default_values(::Type{StringValue}) = (;value = "")
PB.field_numbers(::Type{StringValue}) = (;value = 1)

function PB.decode(d::PB.AbstractProtoDecoder, ::Type{<:StringValue})
function PB.decode(d::PB.AbstractProtoDecoder, ::Type{<:StringValue}, _endpos::Int=0, _group::Bool=false)
value = ""
while !PB.message_done(d)
while !PB.message_done(d, _endpos, _group)
field_number, wire_type = PB.decode_tag(d)
if field_number == 1
value = PB.decode(d, String)
else
PB.skip(d, wire_type)
Base.skip(d, wire_type)
end
end
return StringValue(value)
Expand Down
Loading
Loading