Skip to content

Commit

Permalink
Compilers::ActiveRecordRelations: Improve ordering of the signatures
Browse files Browse the repository at this point in the history
- This swaps the order of the `T.untyped` and `T::Array[T.untyped]`
  signatures when generating RBI files for ActiveRecord methods.
- Some of the methods that define a sig for `T::Array[T.untyped]`, such
  as `new` [1], don't support arrays as input.
- More common than arrays are non-array inputs for ActiveRecord methods,
  so the non-array signature should come first.

[1] - Shopify#1981
  • Loading branch information
issyl0 committed Oct 30, 2024
1 parent 448e639 commit 9e20d27
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 30 deletions.
22 changes: 10 additions & 12 deletions lib/tapioca/dsl/compilers/active_record_relations.rb
Original file line number Diff line number Diff line change
Expand Up @@ -983,18 +983,17 @@ def create_common_methods
method.add_param("attributes")
method.add_block_param("block")

# `T.untyped` matches `T::Array[T.untyped]` so the array signature
# must be defined first for Sorbet to pick it, if valid.
# `T.untyped` matches `T::Array[T.untyped]` so the more common non-array signature must come first.
method.add_sig do |sig|
sig.add_param("attributes", "T::Array[T.untyped]")
sig.add_param("attributes", "T.untyped")
sig.add_param("block", "T.nilable(T.proc.params(object: #{constant_name}).void)")
sig.return_type = "T::Array[#{constant_name}]"
sig.return_type = constant_name
end

method.add_sig do |sig|
sig.add_param("attributes", "T.untyped")
sig.add_param("attributes", "T::Array[T.untyped]")
sig.add_param("block", "T.nilable(T.proc.params(object: #{constant_name}).void)")
sig.return_type = constant_name
sig.return_type = "T::Array[#{constant_name}]"
end
end
end
Expand All @@ -1009,18 +1008,17 @@ def create_common_methods
sig.return_type = constant_name
end

# `T.untyped` matches `T::Array[T.untyped]` so the array signature
# must be defined first for Sorbet to pick it, if valid.
# `T.untyped` matches `T::Array[T.untyped]` so the more common non-array signature must come first.
method.add_sig do |sig|
sig.add_param("attributes", "T::Array[T.untyped]")
sig.add_param("attributes", "T.untyped")
sig.add_param("block", "T.nilable(T.proc.params(object: #{constant_name}).void)")
sig.return_type = "T::Array[#{constant_name}]"
sig.return_type = constant_name
end

method.add_sig do |sig|
sig.add_param("attributes", "T.untyped")
sig.add_param("attributes", "T::Array[T.untyped]")
sig.add_param("block", "T.nilable(T.proc.params(object: #{constant_name}).void)")
sig.return_type = constant_name
sig.return_type = "T::Array[#{constant_name}]"
end
end
end
Expand Down
36 changes: 18 additions & 18 deletions spec/tapioca/dsl/compilers/active_record_relations_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,8 @@ def any?(&block); end
def average(column_name); end
sig { params(block: T.nilable(T.proc.params(object: ::Post).void)).returns(::Post) }
sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(object: ::Post).void)).returns(T::Array[::Post]) }
sig { params(attributes: T.untyped, block: T.nilable(T.proc.params(object: ::Post).void)).returns(::Post) }
sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(object: ::Post).void)).returns(T::Array[::Post]) }
def build(attributes = nil, &block); end
sig { params(operation: Symbol, column_name: T.any(String, Symbol)).returns(T.any(Integer, Float, BigDecimal)) }
Expand All @@ -115,21 +115,21 @@ def calculate(operation, column_name); end
def count(column_name = nil, &block); end
sig { params(block: T.nilable(T.proc.params(object: ::Post).void)).returns(::Post) }
sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(object: ::Post).void)).returns(T::Array[::Post]) }
sig { params(attributes: T.untyped, block: T.nilable(T.proc.params(object: ::Post).void)).returns(::Post) }
sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(object: ::Post).void)).returns(T::Array[::Post]) }
def create(attributes = nil, &block); end
sig { params(block: T.nilable(T.proc.params(object: ::Post).void)).returns(::Post) }
sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(object: ::Post).void)).returns(T::Array[::Post]) }
sig { params(attributes: T.untyped, block: T.nilable(T.proc.params(object: ::Post).void)).returns(::Post) }
sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(object: ::Post).void)).returns(T::Array[::Post]) }
def create!(attributes = nil, &block); end
sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(object: ::Post).void)).returns(T::Array[::Post]) }
sig { params(attributes: T.untyped, block: T.nilable(T.proc.params(object: ::Post).void)).returns(::Post) }
sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(object: ::Post).void)).returns(T::Array[::Post]) }
def create_or_find_by(attributes, &block); end
sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(object: ::Post).void)).returns(T::Array[::Post]) }
sig { params(attributes: T.untyped, block: T.nilable(T.proc.params(object: ::Post).void)).returns(::Post) }
sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(object: ::Post).void)).returns(T::Array[::Post]) }
def create_or_find_by!(attributes, &block); end
sig { returns(T::Array[::Post]) }
Expand Down Expand Up @@ -163,16 +163,16 @@ def find_each(start: nil, finish: nil, batch_size: 1000, error_on_ignore: nil, o
sig { params(start: T.untyped, finish: T.untyped, batch_size: Integer, error_on_ignore: T.untyped, order: Symbol).returns(T::Enumerator[T::Enumerator[::Post]]) }
def find_in_batches(start: nil, finish: nil, batch_size: 1000, error_on_ignore: nil, order: :asc, &block); end
sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(object: ::Post).void)).returns(T::Array[::Post]) }
sig { params(attributes: T.untyped, block: T.nilable(T.proc.params(object: ::Post).void)).returns(::Post) }
sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(object: ::Post).void)).returns(T::Array[::Post]) }
def find_or_create_by(attributes, &block); end
sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(object: ::Post).void)).returns(T::Array[::Post]) }
sig { params(attributes: T.untyped, block: T.nilable(T.proc.params(object: ::Post).void)).returns(::Post) }
sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(object: ::Post).void)).returns(T::Array[::Post]) }
def find_or_create_by!(attributes, &block); end
sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(object: ::Post).void)).returns(T::Array[::Post]) }
sig { params(attributes: T.untyped, block: T.nilable(T.proc.params(object: ::Post).void)).returns(::Post) }
sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(object: ::Post).void)).returns(T::Array[::Post]) }
def find_or_initialize_by(attributes, &block); end
sig { params(signed_id: T.untyped, purpose: T.untyped).returns(T.nilable(::Post)) }
Expand Down Expand Up @@ -241,8 +241,8 @@ def member?(record); end
def minimum(column_name); end
sig { params(block: T.nilable(T.proc.params(object: ::Post).void)).returns(::Post) }
sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(object: ::Post).void)).returns(T::Array[::Post]) }
sig { params(attributes: T.untyped, block: T.nilable(T.proc.params(object: ::Post).void)).returns(::Post) }
sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(object: ::Post).void)).returns(T::Array[::Post]) }
def new(attributes = nil, &block); end
sig { params(block: T.nilable(T.proc.params(record: ::Post).returns(T.untyped))).returns(T::Boolean) }
Expand Down Expand Up @@ -814,8 +814,8 @@ def any?(&block); end
def average(column_name); end
sig { params(block: T.nilable(T.proc.params(object: ::Post).void)).returns(::Post) }
sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(object: ::Post).void)).returns(T::Array[::Post]) }
sig { params(attributes: T.untyped, block: T.nilable(T.proc.params(object: ::Post).void)).returns(::Post) }
sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(object: ::Post).void)).returns(T::Array[::Post]) }
def build(attributes = nil, &block); end
sig { params(operation: Symbol, column_name: T.any(String, Symbol)).returns(T.any(Integer, Float, BigDecimal)) }
Expand All @@ -826,21 +826,21 @@ def calculate(operation, column_name); end
def count(column_name = nil, &block); end
sig { params(block: T.nilable(T.proc.params(object: ::Post).void)).returns(::Post) }
sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(object: ::Post).void)).returns(T::Array[::Post]) }
sig { params(attributes: T.untyped, block: T.nilable(T.proc.params(object: ::Post).void)).returns(::Post) }
sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(object: ::Post).void)).returns(T::Array[::Post]) }
def create(attributes = nil, &block); end
sig { params(block: T.nilable(T.proc.params(object: ::Post).void)).returns(::Post) }
sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(object: ::Post).void)).returns(T::Array[::Post]) }
sig { params(attributes: T.untyped, block: T.nilable(T.proc.params(object: ::Post).void)).returns(::Post) }
sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(object: ::Post).void)).returns(T::Array[::Post]) }
def create!(attributes = nil, &block); end
sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(object: ::Post).void)).returns(T::Array[::Post]) }
sig { params(attributes: T.untyped, block: T.nilable(T.proc.params(object: ::Post).void)).returns(::Post) }
sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(object: ::Post).void)).returns(T::Array[::Post]) }
def create_or_find_by(attributes, &block); end
sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(object: ::Post).void)).returns(T::Array[::Post]) }
sig { params(attributes: T.untyped, block: T.nilable(T.proc.params(object: ::Post).void)).returns(::Post) }
sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(object: ::Post).void)).returns(T::Array[::Post]) }
def create_or_find_by!(attributes, &block); end
sig { returns(T::Array[::Post]) }
Expand Down Expand Up @@ -879,16 +879,16 @@ def find_each(start: nil, finish: nil, batch_size: 1000, error_on_ignore: nil, o
sig { params(start: T.untyped, finish: T.untyped, batch_size: Integer, error_on_ignore: T.untyped, order: Symbol).returns(T::Enumerator[T::Enumerator[::Post]]) }
def find_in_batches(start: nil, finish: nil, batch_size: 1000, error_on_ignore: nil, order: :asc, &block); end
sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(object: ::Post).void)).returns(T::Array[::Post]) }
sig { params(attributes: T.untyped, block: T.nilable(T.proc.params(object: ::Post).void)).returns(::Post) }
sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(object: ::Post).void)).returns(T::Array[::Post]) }
def find_or_create_by(attributes, &block); end
sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(object: ::Post).void)).returns(T::Array[::Post]) }
sig { params(attributes: T.untyped, block: T.nilable(T.proc.params(object: ::Post).void)).returns(::Post) }
sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(object: ::Post).void)).returns(T::Array[::Post]) }
def find_or_create_by!(attributes, &block); end
sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(object: ::Post).void)).returns(T::Array[::Post]) }
sig { params(attributes: T.untyped, block: T.nilable(T.proc.params(object: ::Post).void)).returns(::Post) }
sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(object: ::Post).void)).returns(T::Array[::Post]) }
def find_or_initialize_by(attributes, &block); end
sig { params(signed_id: T.untyped, purpose: T.untyped).returns(T.nilable(::Post)) }
Expand Down Expand Up @@ -957,8 +957,8 @@ def member?(record); end
def minimum(column_name); end
sig { params(block: T.nilable(T.proc.params(object: ::Post).void)).returns(::Post) }
sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(object: ::Post).void)).returns(T::Array[::Post]) }
sig { params(attributes: T.untyped, block: T.nilable(T.proc.params(object: ::Post).void)).returns(::Post) }
sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(object: ::Post).void)).returns(T::Array[::Post]) }
def new(attributes = nil, &block); end
sig { params(block: T.nilable(T.proc.params(record: ::Post).returns(T.untyped))).returns(T::Boolean) }
Expand Down

0 comments on commit 9e20d27

Please sign in to comment.