// Copyright 2018 Google Inc. All rights reserved.
|
//
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
// you may not use this file except in compliance with the License.
|
// You may obtain a copy of the License at:
|
//
|
// http://www.apache.org/licenses/LICENSE-2.0
|
//
|
// Unless required by applicable law or agreed to in writing, software
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
// See the License for the specific language governing permissions and
|
// limitations under the License.
|
|
@testable import TensorFlowLite
|
import XCTest
|
|
/// Unit tests covering construction of `Tensor` and the three ways of building a `TensorShape`.
class TensorTests: XCTestCase {

  // MARK: - Tensor

  /// Checks that every argument handed to `Tensor`'s initializer is exposed unchanged by the
  /// corresponding property.
  func testTensor_Init() {
    let tensorName = "InputTensor"
    let expectedDataType: TensorDataType = .uInt8
    let expectedShape = TensorShape(Constant.dimensions)
    // The payload bytes are arbitrary for this round-trip check; the UTF-8 encoding of the
    // tensor name is simply a convenient source of bytes.
    guard let payload = tensorName.data(using: .utf8) else {
      XCTFail("Data should not be nil.")
      return
    }
    let expectedQuantization = QuantizationParameters(scale: 0.5, zeroPoint: 1)

    let tensor = Tensor(
      name: tensorName,
      dataType: expectedDataType,
      shape: expectedShape,
      data: payload,
      quantizationParameters: expectedQuantization
    )

    XCTAssertEqual(tensor.name, tensorName)
    XCTAssertEqual(tensor.dataType, expectedDataType)
    XCTAssertEqual(tensor.shape, expectedShape)
    XCTAssertEqual(tensor.data, payload)
    XCTAssertEqual(tensor.quantizationParameters, expectedQuantization)
  }

  // MARK: - TensorShape

  /// Builds a shape from an `[Int]` array and checks its rank and dimensions.
  func testTensorShape_InitWithArray() {
    let expectedDimensions = Constant.dimensions
    let builtShape = TensorShape(expectedDimensions)
    XCTAssertEqual(builtShape.rank, expectedDimensions.count)
    XCTAssertEqual(builtShape.dimensions, expectedDimensions)
  }

  /// Builds a shape from variadic elements matching `Constant.dimensions`.
  func testTensorShape_InitWithElements() {
    let builtShape = TensorShape(2, 2, 3)
    let expectedDimensions = Constant.dimensions
    XCTAssertEqual(builtShape.rank, expectedDimensions.count)
    XCTAssertEqual(builtShape.dimensions, expectedDimensions)
  }

  /// Builds a shape from an array literal matching `Constant.dimensions`.
  func testTensorShape_InitWithArrayLiteral() {
    let builtShape: TensorShape = [2, 2, 3]
    let expectedDimensions = Constant.dimensions
    XCTAssertEqual(builtShape.rank, expectedDimensions.count)
    XCTAssertEqual(builtShape.dimensions, expectedDimensions)
  }
}
|
|
// MARK: - Constants
|
|
/// Fixture values shared by the tests in this file.
private enum Constant {
  /// Dimensions of a rank-3 shape, 2 x 2 x 3. An example tensor with this shape is
  /// `[[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]]`.
  static let dimensions: [Int] = [2, 2, 3]
}
|
|
// MARK: - Extensions
|
|
/// `Equatable` conformance used by the `XCTAssertEqual` shape checks in this file.
extension TensorShape: Equatable {
  /// Returns `true` when both shapes report the same rank and identical dimension lists.
  /// The rank is compared first; the dimensions are only compared when the ranks match.
  public static func == (lhs: TensorShape, rhs: TensorShape) -> Bool {
    guard lhs.rank == rhs.rank else { return false }
    return lhs.dimensions == rhs.dimensions
  }
}
|
|
/// `Equatable` conformance for `Tensor`, declared in this test file.
extension Tensor: Equatable {
  /// Returns `true` when name, data type, shape, raw data, and quantization parameters all
  /// match. The fields are compared in that order, and evaluation stops at the first mismatch.
  public static func == (lhs: Tensor, rhs: Tensor) -> Bool {
    guard lhs.name == rhs.name, lhs.dataType == rhs.dataType else { return false }
    guard lhs.shape == rhs.shape, lhs.data == rhs.data else { return false }
    return lhs.quantizationParameters == rhs.quantizationParameters
  }
}
|