Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions tensorflow_text/core/kernels/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,16 @@ cc_library(
],
)

cc_library(
name = "row_splits_validator",
hdrs = ["row_splits_validator.h"],
compatible_with = ["//buildenv/target:prod"],
deps = [
"@com_google_absl//absl/status",
"@com_google_absl//absl/types:span",
],
)

cc_test(
name = "boise_offset_converter_test",
size = "small",
Expand Down Expand Up @@ -62,6 +72,7 @@ tf_cc_library(
],
deps = [
":boise_offset_converter",
":row_splits_validator",
"@com_google_absl//absl/status",
# lite/kernels/shim:op_kernel tensorflow dep,
# lite/kernels/shim:shape tensorflow dep,
Expand Down Expand Up @@ -112,6 +123,7 @@ tf_cc_library(
],
deps = [
":byte_splitter",
":row_splits_validator",
"@com_google_absl//absl/status",
# lite/kernels/shim:op_kernel tensorflow dep,
# lite/kernels/shim:shape tensorflow dep,
Expand Down Expand Up @@ -502,6 +514,7 @@ cc_library(
hdrs = ["fast_wordpiece_tokenizer_kernel_template.h"],
deps = [
":fast_wordpiece_tokenizer",
":row_splits_validator",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
# lite/kernels/shim:op_kernel tensorflow dep,
Expand Down Expand Up @@ -615,6 +628,7 @@ tf_cc_library(
# tf/platform:tstring tensorflow dep,
],
deps = [
":row_splits_validator",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
Expand Down Expand Up @@ -801,6 +815,7 @@ tf_cc_library(
],
deps = [
":round_robin_trimmer",
":row_splits_validator",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/types:span",
Expand Down Expand Up @@ -1020,6 +1035,7 @@ tf_cc_library(
# tf:lib tensorflow dep,
],
deps = [
":row_splits_validator",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
Expand Down Expand Up @@ -1288,6 +1304,7 @@ cc_library(
hdrs = ["phrase_tokenizer_kernel_template.h"],
deps = [
":phrase_tokenizer",
":row_splits_validator",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
# lite/kernels/shim:op_kernel tensorflow dep,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,17 @@
#define THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_BOISE_OFFSET_CONVERTER_KERNEL_TEMPLATE_H_

#include <cstdint>
#include <iostream>
#include <ostream>
#include <string>
#include <vector>

#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/core/platform/tstring.h"
#include "tensorflow/lite/kernels/shim/op_kernel.h"
#include "tensorflow/lite/kernels/shim/shape.h"
#include "tensorflow/lite/kernels/shim/status_macros.h"
#include "tensorflow_text/core/kernels/boise_offset_converter.h"
#include "tensorflow_text/core/kernels/row_splits_validator.h"

namespace tensorflow {
namespace text {
Expand Down Expand Up @@ -304,6 +305,16 @@ absl::Status OffsetsToBoiseTagsOp<Rt>::Invoke(InvokeContext* context) {
}
}

SH_RETURN_IF_ERROR(ValidateRowSplits<int64_t>(
absl::MakeConstSpan(input_token_begin_row_splits_vec.Ptr(),
input_token_begin_row_splits_vec.Dim(0)),
input_token_begin_offsets_vec.Dim(0)));

SH_RETURN_IF_ERROR(ValidateRowSplits<int64_t>(
absl::MakeConstSpan(input_span_begin_row_splits_vec.Ptr(),
input_span_begin_row_splits_vec.Dim(0)),
input_span_begin_offsets_vec.Dim(0)));

// Outputs
std::vector<std::string> boise_tags;
std::vector<int32_t> input_token_begin_offsets_vec_i;
Expand Down Expand Up @@ -562,6 +573,11 @@ absl::Status BoiseTagsToOffsetsOp<Rt>::Invoke(InvokeContext* context) {
}
}

SH_RETURN_IF_ERROR(ValidateRowSplits<int64_t>(
absl::MakeConstSpan(input_token_begin_row_splits_vec.Ptr(),
input_token_begin_row_splits_vec.Dim(0)),
input_token_begin_offsets_vec.Dim(0)));

// Outputs
std::vector<int32_t> span_begin_offsets;
std::vector<int32_t> span_end_offsets;
Expand Down
12 changes: 11 additions & 1 deletion tensorflow_text/core/kernels/byte_splitter_kernel_template.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,18 @@
#ifndef THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_BYTE_SPLITTER_KERNEL_TEMPLATE_H_
#define THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_BYTE_SPLITTER_KERNEL_TEMPLATE_H_

#include <iostream>
#include <cstdint>
#include <string>
#include <vector>

#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/core/platform/tstring.h"
#include "tensorflow/lite/kernels/shim/op_kernel.h"
#include "tensorflow/lite/kernels/shim/shape.h"
#include "tensorflow/lite/kernels/shim/status_macros.h"
#include "tensorflow_text/core/kernels/byte_splitter.h"
#include "tensorflow_text/core/kernels/row_splits_validator.h"

namespace tensorflow {
namespace text {
Expand Down Expand Up @@ -277,6 +280,13 @@ template <tflite::shim::Runtime Rt>
context->GetInput(kInputRowSplits));
const auto in_splits = in_splits_view->template As<int64_t, 1>();

if (starts.Dim(0) != ends.Dim(0)) {
return absl::InvalidArgumentError(
"starts and ends must have the same size.");
}
SH_RETURN_IF_ERROR(ValidateRowSplits<int64_t>(
absl::MakeConstSpan(in_splits.Ptr(), in_splits.Dim(0)), starts.Dim(0)));

ByteSplitter splitter;

// Outputs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,19 @@
#ifndef THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_FAST_WORDPIECE_TOKENIZER_KERNEL_TEMPLATE_H_
#define THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_FAST_WORDPIECE_TOKENIZER_KERNEL_TEMPLATE_H_

#include <cstdint>
#include <string>
#include <vector>

#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/types/span.h"
#include "tensorflow/core/platform/tstring.h"
#include "tensorflow/lite/kernels/shim/op_kernel.h"
#include "tensorflow/lite/kernels/shim/shape.h"
#include "tensorflow/lite/kernels/shim/status_macros.h"
#include "tensorflow_text/core/kernels/fast_wordpiece_tokenizer.h"
#include "tensorflow_text/core/kernels/row_splits_validator.h"

namespace tensorflow {
namespace text {
Expand Down Expand Up @@ -147,7 +155,7 @@ absl::Status FastWordpieceTokenizeWithOffsetsOp<Rt>::Invoke(
// Create() is very cheap.
auto fast_wordpiece_tokenizer =
::tensorflow::text::FastWordpieceTokenizer::Create(
wp_model->template Data<uint8>().data());
wp_model->template Data<uint8_t>().data());
SH_RETURN_IF_ERROR(fast_wordpiece_tokenizer.status());

// TODO(xysong): Optimize based on which information below is requested.
Expand Down Expand Up @@ -180,13 +188,13 @@ absl::Status FastWordpieceTokenizeWithOffsetsOp<Rt>::Invoke(
SH_RETURN_IF_ERROR(this->template FillOutputTensor<std::string,
tensorflow::tstring>(
subwords, kOutputSubwords, context));
SH_RETURN_IF_ERROR(this->template FillOutputTensor<int, int64>(
SH_RETURN_IF_ERROR(this->template FillOutputTensor<int, int64_t>(
subword_ids, kOutputIds, context));
SH_RETURN_IF_ERROR(this->template FillOutputTensor<int, int64>(
SH_RETURN_IF_ERROR(this->template FillOutputTensor<int, int64_t>(
row_splits, kOutputRowSplits, context));
SH_RETURN_IF_ERROR(this->template FillOutputTensor<int, int64>(
SH_RETURN_IF_ERROR(this->template FillOutputTensor<int, int64_t>(
begin_offset, kStartValues, context));
SH_RETURN_IF_ERROR(this->template FillOutputTensor<int, int64>(
SH_RETURN_IF_ERROR(this->template FillOutputTensor<int, int64_t>(
end_offset, kEndValues, context));

return absl::OkStatus();
Expand Down Expand Up @@ -311,15 +319,19 @@ absl::Status FastWordpieceDetokenizeOp<Rt>::Invoke(InvokeContext* context) {

SH_ASSIGN_OR_RETURN(const auto input_row_splits,
context->GetInput(kInputRowSplits));
const auto& row_splits_vec = input_row_splits->template As<int64, 1>();
const auto& row_splits_vec = input_row_splits->template As<int64_t, 1>();

SH_RETURN_IF_ERROR(ValidateRowSplits<int64_t>(
absl::MakeConstSpan(row_splits_vec.Ptr(), row_splits_vec.Dim(0)),
values_vec.Dim(0)));

SH_ASSIGN_OR_RETURN(const auto wp_model, context->GetInput(kWpModel));
// OK to create on every call because FastWordpieceTokenizer is a
// lightweight, memory-mapped wrapper on `wp_model` tensor, and thus
// Create() is very cheap.
auto fast_wordpiece_tokenizer =
::tensorflow::text::FastWordpieceTokenizer::Create(
wp_model->template Data<uint8>().data());
wp_model->template Data<uint8_t>().data());
SH_RETURN_IF_ERROR(fast_wordpiece_tokenizer.status());

std::vector<std::string> sentences;
Expand Down
21 changes: 19 additions & 2 deletions tensorflow_text/core/kernels/ngrams_kernel_template.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,23 @@ limitations under the License.
#ifndef TENSORFLOW_TEXT_CORE_KERNELS_NGRAMS_KERNEL_TEMPLATE_H_
#define TENSORFLOW_TEXT_CORE_KERNELS_NGRAMS_KERNEL_TEMPLATE_H_

#include <algorithm>
#include <cstdint>
#include <cstring>
#include <string>
#include <vector>

#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/core/platform/tstring.h"
#include "tensorflow/lite/kernels/shim/op_kernel.h"
#include "tensorflow/lite/kernels/shim/shape.h"
#include "tensorflow/lite/kernels/shim/status_macros.h"
#include "tensorflow/lite/kernels/shim/tensor_view.h"
#include "tensorflow_text/core/kernels/row_splits_validator.h"

namespace tensorflow {
namespace text {
Expand Down Expand Up @@ -191,6 +200,8 @@ class NgramsStringJoin : public tflite::shim::OpKernelShim<NgramsStringJoin,
Shape(input_tensor_row_splits->Shape())));
const auto input_buffer =
input_tensor_row_splits->template Data<Tsplits>();
SH_RETURN_IF_ERROR(ValidateRowSplits<Tsplits>(
absl::MakeConstSpan(input_buffer.data(), input_buffer.size())));
const auto output_buffer =
output_tensor_row_splits->template Data<Tsplits>();
std::memcpy(output_buffer.data(), input_buffer.data(),
Expand All @@ -214,6 +225,12 @@ class NgramsStringJoin : public tflite::shim::OpKernelShim<NgramsStringJoin,
const auto input_values_data =
input_values->template Data<tensorflow::tstring>();

if (ctx->NumOutputs() != 1) {
SH_RETURN_IF_ERROR(ValidateRowSplits<Tsplits>(
absl::MakeConstSpan(input_row_splits, n_row_splits),
input_values_data.size()));
}

// Create ngrams by looping through the innermost input splits.
std::vector<std::string> buffer;
for (int i = 0; i < n_row_splits - 1; ++i) {
Expand Down Expand Up @@ -247,8 +264,8 @@ class NgramsStringJoin : public tflite::shim::OpKernelShim<NgramsStringJoin,
}

protected:
inline static Shape OutputValuesTensorShape(const Shape& input_values_shape,
const int64_t width) {
static Shape OutputValuesTensorShape(const Shape& input_values_shape,
const int64_t width) {
// If the input shape is unknown, so is the output shape.
if (input_values_shape.Rank() == input_values_shape.kUnknownRank)
return input_values_shape;
Expand Down
22 changes: 17 additions & 5 deletions tensorflow_text/core/kernels/phrase_tokenizer_kernel_template.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,19 @@
#ifndef THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_PHRASE_TOKENIZER_KERNEL_TEMPLATE_H_
#define THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_PHRASE_TOKENIZER_KERNEL_TEMPLATE_H_

#include <cstdint>
#include <string>
#include <vector>

#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/types/span.h"
#include "tensorflow/core/platform/tstring.h"
#include "tensorflow/lite/kernels/shim/op_kernel.h"
#include "tensorflow/lite/kernels/shim/shape.h"
#include "tensorflow/lite/kernels/shim/status_macros.h"
#include "tensorflow_text/core/kernels/phrase_tokenizer.h"
#include "tensorflow_text/core/kernels/row_splits_validator.h"

namespace tensorflow {
namespace text {
Expand Down Expand Up @@ -126,7 +134,7 @@ absl::Status PhraseTokenizeOp<Rt>::Invoke(InvokeContext* context) {
// lightweight, memory-mapped wrapper on `phrase_model` tensor, and thus
// Create() is very cheap.
auto phrase_tokenizer = ::tensorflow::text::PhraseTokenizer::Create(
phrase_model->template Data<uint8>().data());
phrase_model->template Data<uint8_t>().data());
SH_RETURN_IF_ERROR(phrase_tokenizer.status());

std::vector<std::string> subwords;
Expand Down Expand Up @@ -159,13 +167,13 @@ absl::Status PhraseTokenizeOp<Rt>::Invoke(InvokeContext* context) {
kOutputIds,
Shape({static_cast<int>(
subword_ids.size())}))); /* same shape as `output_subwords` */
auto output_ids_vec = output_ids->template As<int64, 1>();
auto output_ids_vec = output_ids->template As<int64_t, 1>();

SH_ASSIGN_OR_RETURN(
auto output_row_splits,
context->GetOutput(kOutputRowSplits,
Shape({static_cast<int>(row_splits.size())})));
auto output_row_splits_vec = output_row_splits->template As<int64, 1>();
auto output_row_splits_vec = output_row_splits->template As<int64_t, 1>();

for (int i = 0; i < subwords.size(); ++i) {
output_subwords_vec(i) = subwords[i];
Expand Down Expand Up @@ -299,14 +307,18 @@ absl::Status PhraseDetokenizeOp<Rt>::Invoke(InvokeContext* context) {

SH_ASSIGN_OR_RETURN(const auto input_row_splits,
context->GetInput(kInputRowSplits));
const auto& row_splits_vec = input_row_splits->template As<int64, 1>();
const auto& row_splits_vec = input_row_splits->template As<int64_t, 1>();

SH_RETURN_IF_ERROR(ValidateRowSplits<int64_t>(
absl::MakeConstSpan(row_splits_vec.Ptr(), row_splits_vec.Dim(0)),
values_vec.Dim(0)));

SH_ASSIGN_OR_RETURN(const auto phrase_model, context->GetInput(kPhraseModel));
// OK to create on every call because PhraseTokenizer is a
// lightweight, memory-mapped wrapper on `phrase_model` tensor, and thus
// Create() is very cheap.
auto phrase_tokenizer = ::tensorflow::text::PhraseTokenizer::Create(
phrase_model->template Data<uint8>().data());
phrase_model->template Data<uint8_t>().data());
SH_RETURN_IF_ERROR(phrase_tokenizer.status());

std::vector<std::string> sentences;
Expand Down
8 changes: 5 additions & 3 deletions tensorflow_text/core/kernels/round_robin_trimmer.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@
#define THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_ROUND_ROBIN_TRIMMER_H_

#include <algorithm>
#include <cstdint>
#include <functional>
#include <utility>
#include <vector>
#include "tensorflow_text/core/kernels/trimmer.h"

#include "tensorflow_text/core/kernels/trimmer.h"

namespace tensorflow {
namespace text {
Expand Down Expand Up @@ -153,7 +154,7 @@ std::vector<Mask> RoundRobinTrimmer<T, Tsplits>::GenerateMasksInternal(
std::vector<Mask> masks(end - begin);
auto m = masks.begin();
for (auto it = begin; it != end; ++it, ++m) {
m->reserve(it->back());
m->reserve(std::max(static_cast<Tsplits>(0), it->empty() ? 0 : it->back()));
}
// Process all batches, updating the masks a batch at a time.
ProcessSplitsByBatch(begin, end, [&masks](std::vector<Row>* rows) {
Expand Down Expand Up @@ -305,7 +306,8 @@ void RoundRobinTrimmer<T, Tsplits>::ProcessSplitsByBatch(
int idx = 0;
for (auto i = begin; i < end; ++i, ++idx) {
value_row_sizes[idx].idx = idx;
value_row_sizes[idx].size = (*i)[batch_idx + 1] - (*i)[batch_idx];
Tsplits row_size = (*i)[batch_idx + 1] - (*i)[batch_idx];
value_row_sizes[idx].size = row_size < 0 ? 0 : row_size;
}
// Perform the main processing of the batch
ProcessBatch(&value_row_sizes, callback);
Expand Down
Loading