// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import ModelSupport
import TensorFlow

/// An image segmentation sample (or batch of samples): `Float` image data paired
/// with an `Int32` label tensor.
///
/// `ImageSegmentationData` uses this type as the element of its training epochs
/// and validation collections, where each element is described as a batch.
/// NOTE(review): the label is presumably a per-pixel class-index mask matching
/// the image's spatial dimensions — confirm against the concrete dataset types.
public typealias SegmentedImage = LabeledData<Tensor<Float>, Tensor<Int32>>

/// Types whose instances provide an image segmentation dataset, with both
/// training and validation data.
public protocol ImageSegmentationData {
  /// The type of the training data, represented as a sequence of epochs, each of
  /// which is a collection of batches.
  associatedtype Training: Sequence
  where Training.Element: Collection, Training.Element.Element == SegmentedImage
  /// The type of the validation data, represented as a collection of batches.
  associatedtype Validation: Collection where Validation.Element == SegmentedImage
  /// Creates an instance that produces batches of the given `batchSize`.
  ///
  /// - Parameters:
  ///   - batchSize: The requested batch size.
  ///   - device: The `Device` to associate with the produced tensors (presumably
  ///     where batches are materialized; conforming types define the details).
  init(batchSize: Int, on device: Device)
  /// The `training` epochs.
  var training: Training { get }
  /// The `validation` batches.
  var validation: Validation { get }

  // TODO: Sample-count requirements will probably be necessary, since that
  // information can't be extracted from `Epochs` or `Batches`. Candidates
  // (not yet required, so existing conforming types are unaffected):
  //   var trainingSampleCount: Int { get }    // samples in the `training` set
  //   var validationSampleCount: Int { get }  // samples in the `validation` set
}